diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..23f2e229ff2fca94b6d015ba18e41ecc1cdfa9af --- /dev/null +++ b/.gitignore @@ -0,0 +1,155 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..0704d550acddc5efeb7a51d156f413a1f791b857 --- /dev/null +++ b/app.py @@ -0,0 +1,205 @@ +# Not ready to use yet +import argparse +import numpy as np +import gradio as gr +from omegaconf import OmegaConf +import torch +from PIL import Image +import PIL +from pipelines import TwoStagePipeline +from huggingface_hub import hf_hub_download +import os +import rembg +from typing import Any +import json +import os +import json +import argparse + +from model import CRM +from inference import generate3d + +pipeline = None +rembg_session = rembg.new_session() + + +def check_input_image(input_image): + if input_image is None: + raise gr.Error("No image uploaded!") + + +def remove_background( + image: PIL.Image.Image, + rembg_session: Any = None, + force: bool = False, + **rembg_kwargs, +) -> PIL.Image.Image: + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + # the alpha channel already provides a mask, so skip background removal + print("alpha channel not empty, skipping background removal and using the alpha channel as mask") + background = Image.new("RGBA", image.size, (0, 0, 0, 0)) + image = Image.alpha_composite(background, image) + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image, session=rembg_session, **rembg_kwargs) + return image + +def do_resize_content(original_image: Image, scale_rate): + # resize the image content while retaining the original image size + if scale_rate != 1: + # Calculate the new size after rescaling + new_size = tuple(int(dim * scale_rate) for dim in original_image.size) + # Resize the image while maintaining the aspect ratio + resized_image = original_image.resize(new_size) + # Create a new transparent image with the original size + padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0)) + paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2) + padded_image.paste(resized_image, paste_position) + return padded_image + else: + return original_image + +def add_background(image, bg_color=(255, 255, 255)): + # given an RGBA image, the alpha channel is used as a mask when compositing onto the background color + background = Image.new("RGBA", image.size, bg_color) + return Image.alpha_composite(background, image) + + +def preprocess_image(input_image, do_remove_background, force_remove, foreground_ratio, backgroud_color): + """ + Input is a PIL image in RGBA; returns an RGB image. + """ + if do_remove_background: + image = remove_background(input_image, rembg_session, force_remove) + image = do_resize_content(image, foreground_ratio) + image = add_background(image, backgroud_color) + return image.convert("RGB") + + +def gen_image(input_image, seed, scale, step): + global pipeline, model, args + pipeline.set_seed(seed) + rt_dict = pipeline(input_image, scale=scale, step=step) + stage1_images = rt_dict["stage1_images"] + stage2_images = rt_dict["stage2_images"] + 
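+ # stage1_images are the multi-view RGB outputs and stage2_images the matching CCM (coordinate map) views; below they are tiled horizontally (axis 1) for preview and passed to generate3d to reconstruct the textured mesh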
np_imgs = np.concatenate(stage1_images, 1) + np_xyzs = np.concatenate(stage2_images, 1) + + glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, args.device) + return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path, obj_path + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--stage1_config", + type=str, + default="configs/nf7_v3_SNR_rd_size_stroke.yaml", + help="config for stage1", +) +parser.add_argument( + "--stage2_config", + type=str, + default="configs/stage2-v2-snr.yaml", + help="config for stage2", +) + +parser.add_argument("--device", type=str, default="cuda") +args = parser.parse_args() + +crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth") +specs = json.load(open("configs/specs_objaverse_total.json")) +model = CRM(specs).to(args.device) +model.load_state_dict(torch.load(crm_path, map_location = args.device), strict=False) + +stage1_config = OmegaConf.load(args.stage1_config).config +stage2_config = OmegaConf.load(args.stage2_config).config +stage2_sampler_config = stage2_config.sampler +stage1_sampler_config = stage1_config.sampler + +stage1_model_config = stage1_config.models +stage2_model_config = stage2_config.models + +xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth") +pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth") +stage1_model_config.resume = pixel_path +stage2_model_config.resume = xyz_path + +pipeline = TwoStagePipeline( + stage1_model_config, + stage2_model_config, + stage1_sampler_config, + stage2_sampler_config, + device=args.device, + dtype=torch.float16 +) + +with gr.Blocks() as demo: + gr.Markdown("# CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model") + with gr.Row(): + with gr.Column(): + with gr.Row(): + image_input = gr.Image( + label="Image input", + image_mode="RGBA", + sources="upload", + type="pil", + ) + processed_image = gr.Image(label="Processed Image", interactive=False, type="pil", image_mode="RGB") + with gr.Row(): + with gr.Column(): + with gr.Row(): + do_remove_background = gr.Checkbox(label="Remove Background", value=True) + force_remove = gr.Checkbox(label="Force Remove", value=False) + back_groud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False) + foreground_ratio = gr.Slider( + label="Foreground Ratio", + minimum=0.5, + maximum=1.0, + value=1.0, + step=0.05, + ) + + with gr.Column(): + seed = gr.Number(value=1234, label="seed", precision=0) + guidance_scale = gr.Number(value=5.5, minimum=0, maximum=20, label="guidance_scale") + step = gr.Number(value=50, minimum=1, maximum=100, label="sample steps", precision=0) + text_button = gr.Button("Generate Images") + with gr.Column(): + image_output = gr.Image(interactive=False, label="Output RGB image") + xyz_ouput = gr.Image(interactive=False, label="Output CCM image") + + output_model = gr.Model3D( + label="Output GLB", + interactive=False, + ) + output_obj = gr.File(interactive=False, label="Output OBJ") + + inputs = [ + processed_image, + seed, + guidance_scale, + step, + ] + outputs = [ + image_output, + xyz_ouput, + output_model, + output_obj, + ] + gr.Examples( + examples=[os.path.join("examples", i) for i in os.listdir("examples")], + inputs=[image_input], + ) + + text_button.click(fn=check_input_image, inputs=[image_input]).success( + fn=preprocess_image, + inputs=[image_input, do_remove_background, force_remove, foreground_ratio, back_groud_color], + outputs=[processed_image], + ).success( + fn=gen_image, + 
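+ # gen_image consumes the preprocessed image plus the seed/guidance/step controls and returns the RGB preview, CCM preview, GLB and OBJ outputs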
inputs=inputs, + outputs=outputs, + ) + demo.queue().launch() diff --git a/configs/nf7_v3_SNR_rd_size_stroke.yaml b/configs/nf7_v3_SNR_rd_size_stroke.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9f81a412c773d0757819a4323384c33b540da96 --- /dev/null +++ b/configs/nf7_v3_SNR_rd_size_stroke.yaml @@ -0,0 +1,21 @@ +config: +# others + seed: 1234 + num_frames: 7 + mode: pixel + offset_noise: true +# model related + models: + config: imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml + resume: models/pixel.pth +# sampler related + sampler: + target: libs.sample.ImageDreamDiffusion + params: + mode: pixel + num_frames: 7 + camera_views: [1, 2, 3, 4, 5, 0, 0] + ref_position: 6 + random_background: false + offset_noise: true + resize_rate: 1.0 \ No newline at end of file diff --git a/configs/specs_objaverse_total.json b/configs/specs_objaverse_total.json new file mode 100644 index 0000000000000000000000000000000000000000..d1541e3c7ac3c93264065567eb11a11eddccf359 --- /dev/null +++ b/configs/specs_objaverse_total.json @@ -0,0 +1,57 @@ +{ + "Input": { + "img_num": 16, + "class": "all", + "camera_angle_num": 8, + "tet_grid_size": 80, + "validate_num": 16, + "scale": 0.95, + "radius": 3, + "resolution": [256, 256] + }, + + "Pretrain": { + "mode": null, + "sdf_threshold": 0.1, + "sdf_scale": 10, + "batch_infer": false, + "lr": 1e-4, + "radius": 0.5 + }, + + "Train": { + "mode": "rnd", + "num_epochs": 500, + "grad_acc": 1, + "warm_up": 0, + "decay": 0.000, + "learning_rate": { + "init": 1e-4, + "sdf_decay": 1, + "rgb_decay": 1 + }, + "batch_size": 4, + "eva_iter": 80, + "eva_all_epoch": 10, + "tex_sup_mode": "blender", + "exp_uv_mesh": false, + "doub": false, + "random_bg": false, + "shift": 0, + "aug_shift": 0, + "geo_type": "flex" + }, + + "ArchSpecs": { + "unet_type": "diffusers", + "use_3D_aware": false, + "fea_concat": false, + "mlp_bias": true + }, + + "DecoderSpecs": { + "c_dim": 32, + "plane_resolution": 256 + } +} + diff --git a/configs/stage2-v2-snr.yaml b/configs/stage2-v2-snr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6dafe186c51fe0aa70739a1bbe983e2a04f42107 --- /dev/null +++ b/configs/stage2-v2-snr.yaml @@ -0,0 +1,25 @@ +config: +# others + seed: 1234 + num_frames: 6 + mode: pixel + offset_noise: true + gd_type: xyz +# model related + models: + config: imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml + resume: models/xyz.pth + +# eval related + sampler: + target: libs.sample.ImageDreamDiffusionStage2 + params: + mode: pixel + num_frames: 6 + camera_views: [1, 2, 3, 4, 5, 0] + ref_position: null + random_background: false + offset_noise: true + resize_rate: 1.0 + + diff --git "a/examples/3D\345\215\241\351\200\232\347\213\227.webp" "b/examples/3D\345\215\241\351\200\232\347\213\227.webp" new file mode 100644 index 0000000000000000000000000000000000000000..108eeb2123a79527d841b8be18b319adacf140b8 Binary files /dev/null and "b/examples/3D\345\215\241\351\200\232\347\213\227.webp" differ diff --git a/examples/astronaut.webp b/examples/astronaut.webp new file mode 100644 index 0000000000000000000000000000000000000000..d08b2f9450a5511309a97e7e9718a963e13f45e7 Binary files /dev/null and b/examples/astronaut.webp differ diff --git a/examples/bulldog.webp b/examples/bulldog.webp new file mode 100644 index 0000000000000000000000000000000000000000..262d06a2acdaedfd4f537efa09b75dbf8d55aed9 Binary files /dev/null and b/examples/bulldog.webp differ diff --git a/examples/ghost-eating-burger.webp b/examples/ghost-eating-burger.webp new file mode 
100644 index 0000000000000000000000000000000000000000..3693930161d65c639c703c4d781c5cec4153464c Binary files /dev/null and b/examples/ghost-eating-burger.webp differ diff --git a/examples/kunkun.webp b/examples/kunkun.webp new file mode 100644 index 0000000000000000000000000000000000000000..1c1486a4b78b8a0c170afe32dc67e22caee50343 Binary files /dev/null and b/examples/kunkun.webp differ diff --git "a/examples/\344\270\207\345\234\243\345\215\227\347\223\234.webp" "b/examples/\344\270\207\345\234\243\345\215\227\347\223\234.webp" new file mode 100644 index 0000000000000000000000000000000000000000..c2fcf0d1ce5931591ac602fbe54f3d866bcba195 Binary files /dev/null and "b/examples/\344\270\207\345\234\243\345\215\227\347\223\234.webp" differ diff --git "a/examples/\344\272\214\346\254\241\345\205\203\345\245\263\345\255\251.webp" "b/examples/\344\272\214\346\254\241\345\205\203\345\245\263\345\255\251.webp" new file mode 100644 index 0000000000000000000000000000000000000000..944fa4260c4637923744b277ce3c4c45d01e1bad Binary files /dev/null and "b/examples/\344\272\214\346\254\241\345\205\203\345\245\263\345\255\251.webp" differ diff --git "a/examples/\344\272\214\346\254\241\345\205\203\345\245\263\345\255\2512.webp" "b/examples/\344\272\214\346\254\241\345\205\203\345\245\263\345\255\2512.webp" new file mode 100644 index 0000000000000000000000000000000000000000..22b490cfd559a0a3dcfd879607b65590f633dbe3 Binary files /dev/null and "b/examples/\344\272\214\346\254\241\345\205\203\345\245\263\345\255\2512.webp" differ diff --git "a/examples/\344\272\272\347\211\251\351\252\221\351\251\254.webp" "b/examples/\344\272\272\347\211\251\351\252\221\351\251\254.webp" new file mode 100644 index 0000000000000000000000000000000000000000..5033e5b8fd4e40a5628c2fad08ba8e6ed26ee121 Binary files /dev/null and "b/examples/\344\272\272\347\211\251\351\252\221\351\251\254.webp" differ diff --git "a/examples/\345\207\257.webp" "b/examples/\345\207\257.webp" new file mode 100644 index 0000000000000000000000000000000000000000..a3eb93cff989e81fe8305dc5ebb54d4ad0762d0b Binary files /dev/null and "b/examples/\345\207\257.webp" differ diff --git "a/examples/\345\210\235\351\237\263\346\234\252\346\235\245\347\216\251\345\201\266.webp" "b/examples/\345\210\235\351\237\263\346\234\252\346\235\245\347\216\251\345\201\266.webp" new file mode 100644 index 0000000000000000000000000000000000000000..df6151a12b81501acc3004ede832b239e649b8b8 Binary files /dev/null and "b/examples/\345\210\235\351\237\263\346\234\252\346\235\245\347\216\251\345\201\266.webp" differ diff --git "a/examples/\345\215\241\351\200\232\346\201\220\351\276\231.webp" "b/examples/\345\215\241\351\200\232\346\201\220\351\276\231.webp" new file mode 100644 index 0000000000000000000000000000000000000000..dceaa89a61d4fed20910b6ea92d5fb4318059ac0 Binary files /dev/null and "b/examples/\345\215\241\351\200\232\346\201\220\351\276\231.webp" differ diff --git "a/examples/\345\215\241\351\200\232\346\211\213\346\236\252\346\210\252\345\233\276.webp" "b/examples/\345\215\241\351\200\232\346\211\213\346\236\252\346\210\252\345\233\276.webp" new file mode 100644 index 0000000000000000000000000000000000000000..a2cf9435498f0473b2fd9ed846e1301f0587e6e2 Binary files /dev/null and "b/examples/\345\215\241\351\200\232\346\211\213\346\236\252\346\210\252\345\233\276.webp" differ diff --git "a/examples/\345\215\241\351\200\232\347\214\253.webp" "b/examples/\345\215\241\351\200\232\347\214\253.webp" new file mode 100644 index 
0000000000000000000000000000000000000000..85bcb9c8e8b87354f0512c6a60f07d800a247635 Binary files /dev/null and "b/examples/\345\215\241\351\200\232\347\214\253.webp" differ diff --git "a/examples/\345\215\241\351\200\232\350\230\221\350\217\207\345\245\227\350\243\205.webp" "b/examples/\345\215\241\351\200\232\350\230\221\350\217\207\345\245\227\350\243\205.webp" new file mode 100644 index 0000000000000000000000000000000000000000..e8db6f69fae78fdf866d92a469adcab5c54f3ab8 Binary files /dev/null and "b/examples/\345\215\241\351\200\232\350\230\221\350\217\207\345\245\227\350\243\205.webp" differ diff --git "a/examples/\345\217\257\347\210\261\347\216\204\347\255\226.webp" "b/examples/\345\217\257\347\210\261\347\216\204\347\255\226.webp" new file mode 100644 index 0000000000000000000000000000000000000000..0516656cc03a56fca7175524c32712acad1f9aef Binary files /dev/null and "b/examples/\345\217\257\347\210\261\347\216\204\347\255\226.webp" differ diff --git "a/examples/\345\244\247\345\244\264\346\263\241\346\263\241\351\251\254\347\211\271.webp" "b/examples/\345\244\247\345\244\264\346\263\241\346\263\241\351\251\254\347\211\271.webp" new file mode 100644 index 0000000000000000000000000000000000000000..a1d5d42c97772ccaddaba5d9ffe15481c811b622 Binary files /dev/null and "b/examples/\345\244\247\345\244\264\346\263\241\346\263\241\351\251\254\347\211\271.webp" differ diff --git "a/examples/\345\275\251\350\211\262\350\230\221\350\217\207.webp" "b/examples/\345\275\251\350\211\262\350\230\221\350\217\207.webp" new file mode 100644 index 0000000000000000000000000000000000000000..a8b7f4af2f78480763317ab3c72c6c00a9240451 Binary files /dev/null and "b/examples/\345\275\251\350\211\262\350\230\221\350\217\207.webp" differ diff --git "a/examples/\345\275\251\350\211\262\350\230\221\350\217\2072.webp" "b/examples/\345\275\251\350\211\262\350\230\221\350\217\2072.webp" new file mode 100644 index 0000000000000000000000000000000000000000..d8d0ad21a29aa37b89d503fa8efc26cdc6dfb6a7 Binary files /dev/null and "b/examples/\345\275\251\350\211\262\350\230\221\350\217\2072.webp" differ diff --git "a/examples/\346\201\220\351\276\231\345\245\227\350\243\205.webp" "b/examples/\346\201\220\351\276\231\345\245\227\350\243\205.webp" new file mode 100644 index 0000000000000000000000000000000000000000..572b9316371295d5a102a3ab07c121a7b0945a2f Binary files /dev/null and "b/examples/\346\201\220\351\276\231\345\245\227\350\243\205.webp" differ diff --git "a/examples/\346\211\213\345\212\236.webp" "b/examples/\346\211\213\345\212\236.webp" new file mode 100644 index 0000000000000000000000000000000000000000..e6adc6e9e50bf8f68903050e3e5dfab7f4514f8c Binary files /dev/null and "b/examples/\346\211\213\345\212\236.webp" differ diff --git "a/examples/\346\234\272\346\242\260\347\213\227\350\243\201\345\210\207.webp" "b/examples/\346\234\272\346\242\260\347\213\227\350\243\201\345\210\207.webp" new file mode 100644 index 0000000000000000000000000000000000000000..a82770f3bf19d0b78a1029532f4f212a58d6c885 Binary files /dev/null and "b/examples/\346\234\272\346\242\260\347\213\227\350\243\201\345\210\207.webp" differ diff --git "a/examples/\346\236\227\345\205\213.webp" "b/examples/\346\236\227\345\205\213.webp" new file mode 100644 index 0000000000000000000000000000000000000000..ec9943125b0b5f17065c47b7d6e3d7bcb84fbe1b Binary files /dev/null and "b/examples/\346\236\227\345\205\213.webp" differ diff --git "a/examples/\346\244\215\347\211\2511.webp" "b/examples/\346\244\215\347\211\2511.webp" new file mode 100644 index 
0000000000000000000000000000000000000000..7c61b204ad39d131f6f33a512e3496e04325c766 Binary files /dev/null and "b/examples/\346\244\215\347\211\2511.webp" differ diff --git "a/examples/\346\257\233\347\272\277\350\241\243.webp" "b/examples/\346\257\233\347\272\277\350\241\243.webp" new file mode 100644 index 0000000000000000000000000000000000000000..10d14d92b9d080992a479c538693e12c7a660caf Binary files /dev/null and "b/examples/\346\257\233\347\272\277\350\241\243.webp" differ diff --git "a/examples/\346\265\267\351\276\237.webp" "b/examples/\346\265\267\351\276\237.webp" new file mode 100644 index 0000000000000000000000000000000000000000..70f811b2432b549b36360e802c43c8d1953f5480 Binary files /dev/null and "b/examples/\346\265\267\351\276\237.webp" differ diff --git "a/examples/\347\214\253\344\272\272.webp" "b/examples/\347\214\253\344\272\272.webp" new file mode 100644 index 0000000000000000000000000000000000000000..b3b5262916e7373ffe1c9ba0df13d0ec5bdc0d1b Binary files /dev/null and "b/examples/\347\214\253\344\272\272.webp" differ diff --git "a/examples/\347\214\253\345\244\264\351\271\260.webp" "b/examples/\347\214\253\345\244\264\351\271\260.webp" new file mode 100644 index 0000000000000000000000000000000000000000..74a85db2dda03f8ef20ea16133c594aee105254a Binary files /dev/null and "b/examples/\347\214\253\345\244\264\351\271\260.webp" differ diff --git "a/examples/\347\216\251\345\205\267\345\205\224.webp" "b/examples/\347\216\251\345\205\267\345\205\224.webp" new file mode 100644 index 0000000000000000000000000000000000000000..d8803f7e261a4fb12ce7d30c9fdba931f9374fc5 Binary files /dev/null and "b/examples/\347\216\251\345\205\267\345\205\224.webp" differ diff --git "a/examples/\347\216\251\345\205\267\347\206\212.webp" "b/examples/\347\216\251\345\205\267\347\206\212.webp" new file mode 100644 index 0000000000000000000000000000000000000000..2dcbf8ea2bff0088d0d5e2daf5dce4381531309e Binary files /dev/null and "b/examples/\347\216\251\345\205\267\347\206\212.webp" differ diff --git "a/examples/\347\216\251\345\205\267\347\214\252.webp" "b/examples/\347\216\251\345\205\267\347\214\252.webp" new file mode 100644 index 0000000000000000000000000000000000000000..10dd7573c542bf7f0ff16f100b0e6e834be7daf6 Binary files /dev/null and "b/examples/\347\216\251\345\205\267\347\214\252.webp" differ diff --git "a/examples/\347\216\253\347\221\260.webp" "b/examples/\347\216\253\347\221\260.webp" new file mode 100644 index 0000000000000000000000000000000000000000..8842a5c63906251e0efa01580006511a827a0b55 Binary files /dev/null and "b/examples/\347\216\253\347\221\260.webp" differ diff --git "a/examples/\347\232\256\345\215\241\344\270\230.webp" "b/examples/\347\232\256\345\215\241\344\270\230.webp" new file mode 100644 index 0000000000000000000000000000000000000000..22bd023feddd3242d70f910c1fbba013e7571909 Binary files /dev/null and "b/examples/\347\232\256\345\215\241\344\270\230.webp" differ diff --git "a/examples/\347\232\256\351\236\213.webp" "b/examples/\347\232\256\351\236\213.webp" new file mode 100644 index 0000000000000000000000000000000000000000..cefc66b1c1b4b4df89e1ba05c0bc93c2fd1300ca Binary files /dev/null and "b/examples/\347\232\256\351\236\213.webp" differ diff --git "a/examples/\347\237\263\345\244\264.webp" "b/examples/\347\237\263\345\244\264.webp" new file mode 100644 index 0000000000000000000000000000000000000000..33f076e861baaca430dd9cf0d5e539f8d51188fd Binary files /dev/null and "b/examples/\347\237\263\345\244\264.webp" differ diff --git 
"a/examples/\347\237\263\345\244\264\345\223\206\345\225\246A\346\242\246.webp" "b/examples/\347\237\263\345\244\264\345\223\206\345\225\246A\346\242\246.webp" new file mode 100644 index 0000000000000000000000000000000000000000..87689487f7b895508da3a651f3becf208fa40b2c Binary files /dev/null and "b/examples/\347\237\263\345\244\264\345\223\206\345\225\246A\346\242\246.webp" differ diff --git "a/examples/\347\272\242\347\216\251\345\205\267\347\214\252.webp" "b/examples/\347\272\242\347\216\251\345\205\267\347\214\252.webp" new file mode 100644 index 0000000000000000000000000000000000000000..38be97f8c0f8d44f84179151ed8108dddafe0f71 Binary files /dev/null and "b/examples/\347\272\242\347\216\251\345\205\267\347\214\252.webp" differ diff --git "a/examples/\347\277\205\350\206\200\351\201\223\345\205\267.webp" "b/examples/\347\277\205\350\206\200\351\201\223\345\205\267.webp" new file mode 100644 index 0000000000000000000000000000000000000000..310e3493df2f9d7069da0a0305fe144aced15f82 Binary files /dev/null and "b/examples/\347\277\205\350\206\200\351\201\223\345\205\267.webp" differ diff --git "a/examples/\350\215\211\347\263\273\347\262\276\347\201\265.webp" "b/examples/\350\215\211\347\263\273\347\262\276\347\201\265.webp" new file mode 100644 index 0000000000000000000000000000000000000000..85655b7e7a18742bda3e5ecbf6e6f07732d9503d Binary files /dev/null and "b/examples/\350\215\211\347\263\273\347\262\276\347\201\265.webp" differ diff --git "a/examples/\350\223\235\350\211\262\345\260\217\346\200\252\347\211\251.webp" "b/examples/\350\223\235\350\211\262\345\260\217\346\200\252\347\211\251.webp" new file mode 100644 index 0000000000000000000000000000000000000000..ab1208f2f49fba628a84b29cb86bf306ea02dbb3 Binary files /dev/null and "b/examples/\350\223\235\350\211\262\345\260\217\346\200\252\347\211\251.webp" differ diff --git "a/examples/\350\223\235\350\211\262\346\263\241\346\263\241\351\251\254\347\211\271.webp" "b/examples/\350\223\235\350\211\262\346\263\241\346\263\241\351\251\254\347\211\271.webp" new file mode 100644 index 0000000000000000000000000000000000000000..0a6b5018021b4eb5c30e2d256e569805e46f5502 Binary files /dev/null and "b/examples/\350\223\235\350\211\262\346\263\241\346\263\241\351\251\254\347\211\271.webp" differ diff --git "a/examples/\350\223\235\350\211\262\347\214\253.webp" "b/examples/\350\223\235\350\211\262\347\214\253.webp" new file mode 100644 index 0000000000000000000000000000000000000000..1019fd1b89ca8a61f35d4af8c7e68d0405bf276c Binary files /dev/null and "b/examples/\350\223\235\350\211\262\347\214\253.webp" differ diff --git "a/examples/\350\265\233\345\215\232\346\234\213\345\205\213-\347\224\267.webp" "b/examples/\350\265\233\345\215\232\346\234\213\345\205\213-\347\224\267.webp" new file mode 100644 index 0000000000000000000000000000000000000000..3d8bc3979499e136b744ab35f835e061ff137599 Binary files /dev/null and "b/examples/\350\265\233\345\215\232\346\234\213\345\205\213-\347\224\267.webp" differ diff --git "a/examples/\350\277\220\345\212\250\347\263\273\346\211\213\345\212\236.webp" "b/examples/\350\277\220\345\212\250\347\263\273\346\211\213\345\212\236.webp" new file mode 100644 index 0000000000000000000000000000000000000000..735d05181c639377431257068625d5d93772e1c8 Binary files /dev/null and "b/examples/\350\277\220\345\212\250\347\263\273\346\211\213\345\212\236.webp" differ diff --git "a/examples/\351\253\230\350\276\276.webp" "b/examples/\351\253\230\350\276\276.webp" new file mode 100644 index 
0000000000000000000000000000000000000000..313c1441b536e09803f5e12ab4f850915acf62f7 Binary files /dev/null and "b/examples/\351\253\230\350\276\276.webp" differ diff --git "a/examples/\351\273\204\351\207\221\351\251\254\350\211\272\346\234\257\345\223\201.webp" "b/examples/\351\273\204\351\207\221\351\251\254\350\211\272\346\234\257\345\223\201.webp" new file mode 100644 index 0000000000000000000000000000000000000000..2966732d3b3d46e4ed1faecc4e6ebcdeab1b8c66 Binary files /dev/null and "b/examples/\351\273\204\351\207\221\351\251\254\350\211\272\346\234\257\345\223\201.webp" differ diff --git a/imagedream/__init__.py b/imagedream/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ba541c412090383da3a7e26113d00c6fe372daa --- /dev/null +++ b/imagedream/__init__.py @@ -0,0 +1 @@ +from .model_zoo import build_model diff --git a/imagedream/camera_utils.py b/imagedream/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3a5d80755135f1c279ec49e0fd46c64962cde0f --- /dev/null +++ b/imagedream/camera_utils.py @@ -0,0 +1,99 @@ +import numpy as np +import torch + + +def create_camera_to_world_matrix(elevation, azimuth): + elevation = np.radians(elevation) + azimuth = np.radians(azimuth) + # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere + x = np.cos(elevation) * np.sin(azimuth) + y = np.sin(elevation) + z = np.cos(elevation) * np.cos(azimuth) + + # Calculate camera position, target, and up vectors + camera_pos = np.array([x, y, z]) + target = np.array([0, 0, 0]) + up = np.array([0, 1, 0]) + + # Construct view matrix + forward = target - camera_pos + forward /= np.linalg.norm(forward) + right = np.cross(forward, up) + right /= np.linalg.norm(right) + new_up = np.cross(right, forward) + new_up /= np.linalg.norm(new_up) + cam2world = np.eye(4) + cam2world[:3, :3] = np.array([right, new_up, -forward]).T + cam2world[:3, 3] = camera_pos + return cam2world + + +def convert_opengl_to_blender(camera_matrix): + if isinstance(camera_matrix, np.ndarray): + # Construct transformation matrix to convert from OpenGL space to Blender space + flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) + camera_matrix_blender = np.dot(flip_yz, camera_matrix) + else: + # Construct transformation matrix to convert from OpenGL space to Blender space + flip_yz = torch.tensor( + [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]] + ) + if camera_matrix.ndim == 3: + flip_yz = flip_yz.unsqueeze(0) + camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix) + return camera_matrix_blender + + +def normalize_camera(camera_matrix): + """normalize the camera location onto a unit-sphere""" + if isinstance(camera_matrix, np.ndarray): + camera_matrix = camera_matrix.reshape(-1, 4, 4) + translation = camera_matrix[:, :3, 3] + translation = translation / ( + np.linalg.norm(translation, axis=1, keepdims=True) + 1e-8 + ) + camera_matrix[:, :3, 3] = translation + else: + camera_matrix = camera_matrix.reshape(-1, 4, 4) + translation = camera_matrix[:, :3, 3] + translation = translation / ( + torch.norm(translation, dim=1, keepdim=True) + 1e-8 + ) + camera_matrix[:, :3, 3] = translation + return camera_matrix.reshape(-1, 16) + + +def get_camera( + num_frames, + elevation=15, + azimuth_start=0, + azimuth_span=360, + blender_coord=True, + extra_view=False, +): + angle_gap = azimuth_span / num_frames + cameras = [] + for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap): + 
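+ # one pose per azimuth step at the given elevation; each 4x4 camera-to-world matrix is flattened to 16 values, so e.g. get_camera(4, elevation=15) returns a (4, 16) tensor of evenly spaced views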
camera_matrix = create_camera_to_world_matrix(elevation, azimuth) + if blender_coord: + camera_matrix = convert_opengl_to_blender(camera_matrix) + cameras.append(camera_matrix.flatten()) + + if extra_view: + dim = len(cameras[0]) + cameras.append(np.zeros(dim)) + return torch.tensor(np.stack(cameras, 0)).float() + + +def get_camera_for_index(data_index): + """ + In our current data layout, view 000 is the one facing the viewer: + 000 is the front view, ev: 0, azimuth: 0 + 001 is the left view, ev: 0, azimuth: -90 + 002 is the bottom view, ev: -90, azimuth: 0 + 003 is the back view, ev: 0, azimuth: 180 + 004 is the right view, ev: 0, azimuth: 90 + 005 is the top view, ev: 90, azimuth: 0 + """ + params = [(0, 0), (0, -90), (-90, 0), (0, 180), (0, 90), (90, 0)] + return get_camera(1, *params[data_index]) \ No newline at end of file diff --git a/imagedream/configs/sd_v2_base_ipmv.yaml b/imagedream/configs/sd_v2_base_ipmv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..234fa01d97eba0fbc1c700d920d4dd4947d9b309 --- /dev/null +++ b/imagedream/configs/sd_v2_base_ipmv.yaml @@ -0,0 +1,61 @@ +model: + target: imagedream.ldm.interface.LatentDiffusionInterface + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + scale_factor: 0.18215 + parameterization: "eps" + + unet_config: + target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + camera_dim: 16 + with_ip: True + ip_dim: 16 # ip token length + ip_mode: "local_resample" + + vae_config: + target: imagedream.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + clip_config: + target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + ip_mode: "local_resample" diff --git a/imagedream/configs/sd_v2_base_ipmv_ch8.yaml b/imagedream/configs/sd_v2_base_ipmv_ch8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1979885bb2a22db2002cfe52345ffcb29f64addc --- /dev/null +++ b/imagedream/configs/sd_v2_base_ipmv_ch8.yaml @@ -0,0 +1,61 @@ +model: + target: imagedream.ldm.interface.LatentDiffusionInterface + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + scale_factor: 0.18215 + parameterization: "eps" + + unet_config: + target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel + params: + image_size: 32 # unused + in_channels: 8 + out_channels: 8 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + camera_dim: 16 + with_ip: True + ip_dim: 16 # ip token length + ip_mode: "local_resample" + + vae_config: + target: imagedream.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + 
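+ # note (assumption): this ddconfig mirrors the standard SD 2.x KL autoencoder, i.e. 4-channel latents and 8x spatial downsampling from ch_mult [1, 2, 4, 4]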
double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + clip_config: + target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + ip_mode: "local_resample" diff --git a/imagedream/configs/sd_v2_base_ipmv_chin8.yaml b/imagedream/configs/sd_v2_base_ipmv_chin8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b6010ded905a3233d0474491c374f45ff6649503 --- /dev/null +++ b/imagedream/configs/sd_v2_base_ipmv_chin8.yaml @@ -0,0 +1,61 @@ +model: + target: imagedream.ldm.interface.LatentDiffusionInterface + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + scale_factor: 0.18215 + parameterization: "eps" + + unet_config: + target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2 + params: + image_size: 32 # unused + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + camera_dim: 16 + with_ip: True + ip_dim: 16 # ip token length + ip_mode: "local_resample" + + vae_config: + target: imagedream.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + clip_config: + target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + ip_mode: "local_resample" diff --git a/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml b/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa9e1c8fc87dfef6c06ef2db3160d828ab95ae2f --- /dev/null +++ b/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml @@ -0,0 +1,62 @@ +model: + target: imagedream.ldm.interface.LatentDiffusionInterface + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + scale_factor: 0.18215 + parameterization: "eps" + zero_snr: true + + unet_config: + target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2 + params: + image_size: 32 # unused + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + camera_dim: 16 + with_ip: True + ip_dim: 16 # ip token length + ip_mode: "local_resample" + + vae_config: + target: imagedream.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + clip_config: + target: 
imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + ip_mode: "local_resample" diff --git a/imagedream/configs/sd_v2_base_ipmv_local.yaml b/imagedream/configs/sd_v2_base_ipmv_local.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f50e72870272346632d91849723a49221317ee30 --- /dev/null +++ b/imagedream/configs/sd_v2_base_ipmv_local.yaml @@ -0,0 +1,62 @@ +model: + target: imagedream.ldm.interface.LatentDiffusionInterface + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + scale_factor: 0.18215 + parameterization: "eps" + + unet_config: + target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + camera_dim: 16 + with_ip: True + ip_dim: 16 # ip token length + ip_mode: "local_resample" + ip_weight: 1.0 # adjust for similarity to image + + vae_config: + target: imagedream.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + clip_config: + target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + ip_mode: "local_resample" diff --git a/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml b/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d0115e086043c4dca06dfbf1a0fd03b4f4b6bec8 --- /dev/null +++ b/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml @@ -0,0 +1,62 @@ +model: + target: imagedream.ldm.interface.LatentDiffusionInterface + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + scale_factor: 0.18215 + parameterization: "eps" + zero_snr: true + + unet_config: + target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + camera_dim: 16 + with_ip: True + ip_dim: 16 # ip token length + ip_mode: "local_resample" + + vae_config: + target: imagedream.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + clip_config: + target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + ip_mode: "local_resample" diff --git a/imagedream/ldm/__init__.py b/imagedream/ldm/__init__.py new file mode 100644 
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imagedream/ldm/interface.py b/imagedream/ldm/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d417c895489ba514008fd11bfc18571897113a --- /dev/null +++ b/imagedream/ldm/interface.py @@ -0,0 +1,205 @@ +from typing import List +from functools import partial + +import numpy as np +import torch +import torch.nn as nn + +from .modules.diffusionmodules.util import ( + make_beta_schedule, + extract_into_tensor, + enforce_zero_terminal_snr, + noise_like, +) +from .util import exists, default, instantiate_from_config +from .modules.distributions.distributions import DiagonalGaussianDistribution + + +class DiffusionWrapper(nn.Module): + def __init__(self, diffusion_model): + super().__init__() + self.diffusion_model = diffusion_model + + def forward(self, *args, **kwargs): + return self.diffusion_model(*args, **kwargs) + + +class LatentDiffusionInterface(nn.Module): + """a simple interface class for LDM inference""" + + def __init__( + self, + unet_config, + clip_config, + vae_config, + parameterization="eps", + scale_factor=0.18215, + beta_schedule="linear", + timesteps=1000, + linear_start=0.00085, + linear_end=0.0120, + cosine_s=8e-3, + given_betas=None, + zero_snr=False, + *args, + **kwargs, + ): + super().__init__() + + unet = instantiate_from_config(unet_config) + self.model = DiffusionWrapper(unet) + self.clip_model = instantiate_from_config(clip_config) + self.vae_model = instantiate_from_config(vae_config) + + self.parameterization = parameterization + self.scale_factor = scale_factor + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + zero_snr=zero_snr + ) + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + zero_snr=False + ): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + if zero_snr: + print("--- using zero snr---") + betas = enforce_zero_terminal_snr(betas).numpy() + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) 
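+ # with v_posterior = 0 this is the standard DDPM posterior variance beta_t * (1 - alphabar_{t-1}) / (1 - alphabar_t)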
+ self.v_posterior = 0 + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) + * x_t + ) + + def apply_model(self, x_noisy, t, cond, **kwargs): + assert isinstance(cond, dict), "cond has to be a dictionary" + return self.model(x_noisy, t, **cond, **kwargs) + + def get_learned_conditioning(self, prompts: List[str]): + return self.clip_model(prompts) + + def get_learned_image_conditioning(self, images): + return self.clip_model.forward_image(images) + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def encode_first_stage(self, x): + return self.vae_model.encode(x) + + def decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + return self.vae_model.decode(z) diff --git a/imagedream/ldm/models/__init__.py b/imagedream/ldm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imagedream/ldm/models/autoencoder.py b/imagedream/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8941041399e08f538c5e32f523eba54889c5bd4c --- /dev/null +++ b/imagedream/ldm/models/autoencoder.py @@ -0,0 +1,270 @@ +import torch +import torch.nn.functional as F +from contextlib import contextmanager + +from ..modules.diffusionmodules.model import Encoder, Decoder +from ..modules.distributions.distributions import 
DiagonalGaussianDistribution + +from ..util import instantiate_from_config +from ..modules.ema import LitEma + + +class AutoencoderKL(torch.nn.Module): + def __init__( + self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False, + ): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + if self.use_ema: + self.ema_decay = ema_decay + assert 0.0 < ema_decay < 1.0 + self.model_ema = LitEma(self, decay=ema_decay) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log( + "aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False + ) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + 
last_layer=self.get_last_layer(), + split="train", + ) + + self.log( + "discloss", + discloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False + ) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) + + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) + + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + ae_params_list = ( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()) + ) + if self.learn_logvar: + print(f"{self.__class__.__name__}: Learning logvar") + ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam( + self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) + ) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode( + torch.randn_like(posterior_ema.sample()) + ) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/imagedream/ldm/models/diffusion/__init__.py b/imagedream/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff 
--git a/imagedream/ldm/models/diffusion/ddim.py b/imagedream/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e1372f2ec75b16d3c35ca7fa7111321f8a4ade --- /dev/null +++ b/imagedream/ldm/models/diffusion/ddim.py @@ -0,0 +1,430 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ...modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, + extract_into_tensor, +) + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
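+ # unconditional_conditioning mirrors `conditioning` and enables classifier-free guidance whenever unconditional_guidance_scale != 1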
+ **kwargs, + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + **kwargs, + ): + """ + when inference time: all values of parameter + cond.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img']) + shape: (5, 4, 32, 32) + x_T: None + ddim_use_original_steps: False + timesteps: None + callback: None + quantize_denoised: False + mask: None + image_callback: None + log_every_t: 100 + temperature: 1.0 + noise_dropout: 0.0 + score_corrector: None + corrector_kwargs: None + unconditional_guidance_scale: 5 + unconditional_conditioning.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img']) + kwargs: {} + """ + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) # shape: torch.Size([5, 4, 32, 32]) mean: -0.00, std: 1.00, min: -3.64, max: 3.94 + else: + img = x_T + + if timesteps is None: # equal with set time step in hf + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( # reversed timesteps + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? 
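+ # blend: keep the (noised) known content where mask == 1 and the evolving sample where mask == 0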
+ img = img_orig * mask + (1.0 - mask) * img + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + img, pred_x0 = outs + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + **kwargs, + ): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: + model_output = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + elif isinstance(c[k], torch.Tensor): + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + assert c[k] == unconditional_conditioning[k] + c_in[k] = c[k] + elif isinstance(c, list): + c_in = list() + assert isinstance(unconditional_conditioning, list) + for i in range(len(c)): + c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) + else: + c_in = torch.cat([unconditional_conditioning, c]) + model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + model_output = model_uncond + unconditional_guidance_scale * ( + model_t - model_uncond + ) + + + if self.model.parameterization == "v": + print("using v!") + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", "not implemented" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, 
_, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + **kwargs, + ): + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full( + (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + return x_dec diff --git a/imagedream/ldm/modules/__init__.py b/imagedream/ldm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imagedream/ldm/modules/attention.py b/imagedream/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0693788a7dc44170aaee8bc90c13c4dfee13e4f9 --- /dev/null +++ b/imagedream/ldm/modules/attention.py @@ -0,0 +1,456 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from typing import Optional, Any + +from .diffusionmodules.util import checkpoint + + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + +# CrossAttn precision handling +import os + +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate 
= self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs): + super().__init__() + print( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads." 
+ ) + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.with_ip = kwargs.get("with_ip", False) + if self.with_ip and (context_dim is not None): + self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) + self.ip_dim= kwargs.get("ip_dim", 16) + self.ip_weight = kwargs.get("ip_weight", 1.0) + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + + has_ip = self.with_ip and (context is not None) + if has_ip: + # context dim [(b frame_num), (77 + img_token), 1024] + token_len = context.shape[1] + context_ip = context[:, -self.ip_dim:, :] + k_ip = self.to_k_ip(context_ip) + v_ip = self.to_v_ip(context_ip) + context = context[:, :(token_len - self.ip_dim), :] + + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + if has_ip: + k_ip, v_ip = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (k_ip, v_ip), + ) + # actually compute the attention, what we cannot get enough of + out_ip = xformers.ops.memory_efficient_attention( + q, k_ip, v_ip, attn_bias=None, op=self.attention_op + ) + out = out + self.ip_weight * out_ip + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + **kwargs + ): + super().__init__() + assert XFORMERS_IS_AVAILBLE, "xformers is not available" + attn_cls = MemoryEfficientCrossAttention + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + ) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + **kwargs + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint( + self._forward, (x, context), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None): + x = ( + self.attn1( + self.norm1(x), context=context if self.disable_self_attn else None + ) + + x + ) + x = 
self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + **kwargs + ): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, + **kwargs + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class BasicTransformerBlock3D(BasicTransformerBlock): + def forward(self, x, context=None, num_frames=1): + return checkpoint( + self._forward, (x, context, num_frames), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None, num_frames=1): + x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous() + x = ( + self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None + ) + + x + ) + x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous() + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer3D(nn.Module): + """3D self-attention""" + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + **kwargs + ): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock3D( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + 
disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, + **kwargs + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None, num_frames=1): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i], num_frames=num_frames) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in diff --git a/imagedream/ldm/modules/diffusionmodules/__init__.py b/imagedream/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imagedream/ldm/modules/diffusionmodules/adaptors.py b/imagedream/ldm/modules/diffusionmodules/adaptors.py new file mode 100644 index 0000000000000000000000000000000000000000..547e7f4173c85f7f62ff76e38e9794ee2f31660e --- /dev/null +++ b/imagedream/ldm/modules/diffusionmodules/adaptors.py @@ -0,0 +1,163 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +import math + +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = 
out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + def __init__(self, + cross_attention_dim=1024, + clip_embeddings_dim=1024, + clip_extra_context_tokens=4): + super().__init__() + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + + # from 1024 -> 4 * 1024 + self.proj = torch.nn.Linear( + clip_embeddings_dim, + self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class SimpleReSampler(nn.Module): + def __init__(self, embedding_dim=1280, output_dim=1024): + super().__init__() + self.proj_out = nn.Linear(embedding_dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + def forward(self, latents): + """ + latents: B 256 N + """ + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + self.proj_in = nn.Linear(embedding_dim, dim) + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, + dim_head=dim_head, + heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + latents = self.latents.repeat(x.size(0), 1, 1) + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +if __name__ == '__main__': + resampler = Resampler(embedding_dim=1280) + resampler = SimpleReSampler(embedding_dim=1280) + tensor = torch.rand(4, 257, 1280) + embed = resampler(tensor) + # embed = (tensor) + print(embed.shape) diff --git a/imagedream/ldm/modules/diffusionmodules/model.py b/imagedream/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..090df6b33cccee7950a1d5719ec7513cc61d588f --- /dev/null +++ b/imagedream/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,1018 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +from typing import Optional, Any + +from ..attention import MemoryEfficientCrossAttention + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + print("No module 'xformers'. Proceeding without it.") + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
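+ Returns a tensor of shape [len(timesteps), embedding_dim]: sin terms in the first
+ half of the channels and cos terms in the second half, with one trailing zero
+ channel if embedding_dim is odd.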
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + 
self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x + out + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None): + b, c, h, w = x.shape + x = rearrange(x, "b c h w -> b (h w) c") + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in [ + "vanilla", + "vanilla-xformers", + "memory-efficient-cross-attn", + "linear", + "none", + ], f"attn_type {attn_type} unknown" + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + 
raise NotImplementedError() + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in 
range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = 
nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + 
dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + 
in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x diff --git a/imagedream/ldm/modules/diffusionmodules/openaimodel.py b/imagedream/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..3d28ecbc9718438a8c6b720d78b2bf52f641d4cc --- /dev/null +++ b/imagedream/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1135 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from imagedream.ldm.modules.diffusionmodules.util import ( + checkpoint, + 
conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, + convert_module_to_f16, + convert_module_to_f32 +) +from imagedream.ldm.modules.attention import ( + SpatialTransformer, + SpatialTransformer3D, + exists +) +from imagedream.ldm.modules.diffusionmodules.adaptors import ( + Resampler, + ImageProjModel +) + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, num_frames=1): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer3D): + x = layer(x, context, num_frames=num_frames) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. 
+ :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
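+ When use_checkpoint is True, the computation is expected to run under gradient
+ checkpointing (activations recomputed in the backward pass) to reduce memory usage.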
+ """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
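+ Here H is `self.n_heads` and C the per-head channel width. The legacy layout
+ folds the heads into the batch dimension first and then splits q, k, v, whereas
+ `QKVAttention` below splits q, k, v before reshaping per head.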
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class Timestep(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + return timestep_embedding(t, self.dim) + + +class MultiViewUNetModel(nn.Module): + """ + The full multi-view UNet model with attention, timestep embedding and camera embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + :param camera_dim: dimensionality of camera input. 
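+ :param with_ip: whether to condition on image-prompt (IP) tokens.
+ :param ip_dim: number of extra image-prompt tokens (4 for the global adaptor,
+ 16 for local_resample).
+ :param ip_weight: weight applied to the image-prompt context.
+ :param ip_mode: adaptor mode, either "global" or "local_resample".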
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + use_bf16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + adm_in_channels=None, + camera_dim=None, + with_ip=False, # wether add image prompt images + ip_dim=0, # number of extra token, 4 for global 16 for local + ip_weight=1.0, # weight for image prompt context + ip_mode="local_resample", # which mode of adaptor, global or local + ): + super().__init__() + if use_spatial_transformer: + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + from omegaconf.listconfig import ListConfig + + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." 
+ ) + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.dtype = th.bfloat16 if use_bf16 else self.dtype + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + self.with_ip = with_ip # wether there is image prompt + self.ip_dim = ip_dim # num of extra token, 4 for global 16 for local + self.ip_weight = ip_weight + assert ip_mode in ["global", "local_resample"] + self.ip_mode = ip_mode # which mode of adaptor + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if camera_dim is not None: + time_embed_dim = model_channels * 4 + self.camera_embed = nn.Sequential( + linear(camera_dim, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + if self.with_ip and (context_dim is not None) and ip_dim > 0: + if self.ip_mode == "local_resample": + # ip-adapter-plus + hidden_dim = 1280 + self.image_embed = Resampler( + dim=context_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=ip_dim, # num token + embedding_dim=hidden_dim, + output_dim=context_dim, + ff_mult=4, + ) + elif self.ip_mode == "global": + self.image_embed = ImageProjModel( + cross_attention_dim=context_dim, + clip_extra_context_tokens=ip_dim) + else: + raise ValueError(f"{self.ip_mode} is not supported") + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or nr < num_attention_blocks[level] + ): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer3D( + ch, + num_heads, + 
dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + with_ip=self.with_ip, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer3D( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + with_ip=self.with_ip, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or i < num_attention_blocks[level] + ): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer3D( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + with_ip=self.with_ip, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + 
out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward( + self, + x, + timesteps=None, + context=None, + y=None, + camera=None, + num_frames=1, + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views). + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning context (text and image-prompt tokens) applied via cross-attention. + :param y: an [N] Tensor of labels, if class-conditional, default None. + :param num_frames: an integer indicating the number of frames (views) for tensor reshaping. + :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views). + """ + assert ( + x.shape[0] % num_frames == 0 + ), "[UNet] input batch size must be divisible by num_frames!"
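+ # Views of the same object are packed contiguously along the batch axis (the
+ # slicing below assumes frame f of object n sits at row n * num_frames + f);
+ # the optional `ip_img` kwarg overwrites the last view slot of every object.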
+ assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00 + emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51 + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # Add camera embeddings + if camera is not None: + assert camera.shape[0] == emb.shape[0] + # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04 + emb = emb + self.camera_embed(camera) + ip = kwargs.get("ip", None) + ip_img = kwargs.get("ip_img", None) + + if ip_img is not None: + x[(num_frames-1)::num_frames, :, :, :] = ip_img + + if ip is not None: + ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31 + context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31 + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context, num_frames=num_frames) + hs.append(h) + h = self.middle_block(h, emb, context, num_frames=num_frames) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context, num_frames=num_frames) + h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58 + if self.predict_codebook_ids: # False + return self.id_predictor(h) + else: + return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93 + + + + +class MultiViewUNetModelStage2(MultiViewUNetModel): + """ + The full multi-view UNet model with attention, timestep embedding and camera embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + :param camera_dim: dimensionality of camera input. 
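+ Unlike the base MultiViewUNetModel, forward() here also expects a `pixel_images`
+ keyword argument that is concatenated to the latent input along the channel
+ dimension, so `in_channels` must account for these extra channels.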
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + use_bf16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + adm_in_channels=None, + camera_dim=None, + with_ip=False, # whether to add image prompt images + ip_dim=0, # number of extra tokens, 4 for global, 16 for local + ip_weight=1.0, # weight for image prompt context + ip_mode="local_resample", # which adaptor mode, "global" or "local_resample" + ): + super().__init__( + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout, + channel_mult, + conv_resample, + dims, + num_classes, + use_checkpoint, + use_fp16, + use_bf16, + num_heads, + num_head_channels, + num_heads_upsample, + use_scale_shift_norm, + resblock_updown, + use_new_attention_order, + use_spatial_transformer, + transformer_depth, + context_dim, + n_embed, + legacy, + disable_self_attentions, + num_attention_blocks, + disable_middle_self_attn, + use_linear_in_transformer, + adm_in_channels, + camera_dim, + with_ip, + ip_dim, + ip_weight, + ip_mode, + ) + + def forward( + self, + x, + timesteps=None, + context=None, + y=None, + camera=None, + num_frames=1, + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views). + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning context (text and image-prompt tokens) applied via cross-attention. + :param y: an [N] Tensor of labels, if class-conditional, default None. + :param num_frames: an integer indicating the number of frames (views) for tensor reshaping. + :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views). + """ + assert ( + x.shape[0] % num_frames == 0 + ), "[UNet] input batch size must be divisible by num_frames!"
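+ # Same per-object frame layout as the base class; in addition, the `pixel_images`
+ # kwarg is concatenated onto the latent channels a few lines below, so callers
+ # must always pass it to this stage-2 model.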
+ assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00 + emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51 + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # Add camera embeddings + if camera is not None: + assert camera.shape[0] == emb.shape[0] + # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04 + emb = emb + self.camera_embed(camera) + ip = kwargs.get("ip", None) + ip_img = kwargs.get("ip_img", None) + pixel_images = kwargs.get("pixel_images", None) + + if ip_img is not None: + x[(num_frames-1)::num_frames, :, :, :] = ip_img + + x = torch.cat((x, pixel_images), dim=1) + + if ip is not None: + ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31 + context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31 + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context, num_frames=num_frames) + hs.append(h) + h = self.middle_block(h, emb, context, num_frames=num_frames) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context, num_frames=num_frames) + h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58 + if self.predict_codebook_ids: # False + return self.id_predictor(h) + else: + return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93 + \ No newline at end of file diff --git a/imagedream/ldm/modules/diffusionmodules/util.py b/imagedream/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..665af3eeb177bb424179e7b9b06008c510533c5d --- /dev/null +++ b/imagedream/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,353 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! 
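+# +# This module collects the small building blocks used by the diffusion UNet: +# beta/DDIM schedule helpers, the gradient-checkpointing wrapper `checkpoint`, +# sinusoidal `timestep_embedding`, and layer factories such as conv_nd, linear, +# avg_pool_nd, normalization and zero_module. +# +# Typical usage elsewhere in this repo (sketch, see MultiViewUNetModel.forward): +# t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) +# emb = self.time_embed(t_emb)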
+ + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat +import importlib + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + +def enforce_zero_terminal_snr(betas): + betas = torch.tensor(betas) if not isinstance(betas, torch.Tensor) else betas + # Convert betas to alphas_bar_sqrt + alphas =1 - betas + alphas_bar = alphas.cumprod(0) + alphas_bar_sqrt = alphas_bar.sqrt() + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + # Shift so last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + # Scale so first timestep is back to old value. 
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt ** 2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat ([alphas_bar[0:1], alphas]) + betas = 1 - alphas + return betas + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. 
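+ Example, as used by ResBlock.forward in the UNet model: + `return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)`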
+ """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + # import pdb; pdb.set_trace() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. 
+ """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +# dummy replace +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() diff --git a/imagedream/ldm/modules/distributions/__init__.py b/imagedream/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imagedream/ldm/modules/distributions/distributions.py b/imagedream/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..0b61f03077358ce4737c85842d9871f70dabb656 --- /dev/null +++ b/imagedream/ldm/modules/distributions/distributions.py @@ -0,0 +1,102 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * 
torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/imagedream/ldm/modules/ema.py b/imagedream/ldm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..97b5ae2b230f89b4dba57e44c4f851478ad86f68 --- /dev/null +++ b/imagedream/ldm/modules/ema.py @@ -0,0 +1,86 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. 
+ Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/imagedream/ldm/modules/encoders/__init__.py b/imagedream/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imagedream/ldm/modules/encoders/modules.py b/imagedream/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..8efe14f3aeb29cd93203ea0f32a36f624c81aafd --- /dev/null +++ b/imagedream/ldm/modules/encoders/modules.py @@ -0,0 +1,329 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +import numpy as np +import open_clip +from PIL import Image +from ...util import default, count_params + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0.0 and not disable_dropout: + mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = ( + self.n_classes - 1 + ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? 
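+ # 77 (the default above) mirrors the CLIP text-encoder context length; T5 itself
+ # does not impose a 77-token limit, so this is just a convenient default here.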
+ if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None, + ): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer( + input_ids=tokens, output_hidden_states=self.layer == "hidden" + ) + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEncoder, nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + + LAYERS = [ + # "pooled", + "last", + "penultimate", + ] + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="last", + ip_mode=None + ): + """_summary_ + + Args: + ip_mode (str, optional): what is the image promcessing mode. Defaults to None. 
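+ Expected values: None (text-only, the CLIP visual tower is deleted), "global"
+ (pooled CLIP image embedding), or "local_resample" (patch tokens taken from
+ the penultimate visual transformer block).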
+ + """ + super().__init__() + assert layer in self.LAYERS + model, _, preprocess = open_clip.create_model_and_transforms( + arch, device=torch.device("cpu"), pretrained=version + ) + if ip_mode is None: + del model.visual + + self.model = model + self.preprocess = preprocess + self.device = device + self.max_length = max_length + self.ip_mode = ip_mode + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def forward_image(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + if isinstance(pil_image, torch.Tensor): + pil_image = pil_image.cpu().numpy() + if isinstance(pil_image, np.ndarray): + if pil_image.ndim == 3: + pil_image = pil_image[None, :, :, :] + pil_image = [Image.fromarray(x) for x in pil_image] + + images = [] + for image in pil_image: + images.append(self.preprocess(image).to(self.device)) + + image = torch.stack(images, 0) # to [b, 3, h, w] + if self.ip_mode == "global": + image_features = self.model.encode_image(image) + image_features /= image_features.norm(dim=-1, keepdim=True) + elif "local" in self.ip_mode: + image_features = self.encode_image_with_transformer(image) + + return image_features # b, l + + def encode_image_with_transformer(self, x): + visual = self.model.visual + x = visual.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [visual.class_embedding.to(x.dtype) + \ + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + visual.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in + # x = visual.patch_dropout(x) + x = visual.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + hidden = self.image_transformer_forward(x) + x = hidden[-2].permute(1, 0, 2) # LND -> NLD + return x + + def image_transformer_forward(self, x): + encoder_states = () + trans = self.model.visual.transformer + for r in trans.resblocks: + if trans.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, None) + else: + x = r(x, attn_mask=None) + encoder_states = encoder_states + (x, ) + return encoder_states + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if ( + self.model.transformer.grad_checkpointing + and not torch.jit.is_scripting() + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__( + self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77, + ): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder( + clip_version, device, max_length=clip_max_length + ) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print( + f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params." + ) + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] diff --git a/imagedream/ldm/util.py b/imagedream/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..f541504a1181f6164b1084a6ffb0de365a9b06ba --- /dev/null +++ b/imagedream/ldm/util.py @@ -0,0 +1,226 @@ +import importlib + +import random +import torch +import numpy as np +from collections import abc + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join( + xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) + ) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. 
Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + # import pdb; pdb.set_trace() + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + # import pdb; pdb.set_trace() + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, + data, + n_proc, + target_data_type="ndarray", + cpu_intensive=True, + use_worker_id=False, +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
+ ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i : i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == "ndarray": + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == "list": + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res + +def set_seed(seed=None): + random.seed(seed) + np.random.seed(seed) + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def add_random_background(image, bg_color=None): + bg_color = np.random.rand() * 255 if bg_color is None else bg_color + image = np.array(image) + rgb, alpha = image[..., :3], image[..., 3:] + alpha = alpha.astype(np.float32) / 255.0 + image_new = rgb * alpha + bg_color * (1 - alpha) + return Image.fromarray(image_new.astype(np.uint8)) \ No newline at end of file diff --git a/imagedream/model_zoo.py b/imagedream/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..d65529ab1b717e3d60a0d38e2fb9fea1abf0c006 --- /dev/null +++ b/imagedream/model_zoo.py @@ -0,0 +1,64 @@ +""" Utiliy functions to load pre-trained models more easily """ +import os +import pkg_resources +from omegaconf import OmegaConf + +import torch +from huggingface_hub import hf_hub_download + +from imagedream.ldm.util import instantiate_from_config + + +PRETRAINED_MODELS = { + "sd-v2.1-base-4view-ipmv": { + "config": "sd_v2_base_ipmv.yaml", + "repo_id": "Peng-Wang/ImageDream", + "filename": "sd-v2.1-base-4view-ipmv.pt", + }, + "sd-v2.1-base-4view-ipmv-local": { + "config": "sd_v2_base_ipmv_local.yaml", + "repo_id": "Peng-Wang/ImageDream", + "filename": "sd-v2.1-base-4view-ipmv-local.pt", + }, +} + + +def get_config_file(config_path): + cfg_file = pkg_resources.resource_filename( + "imagedream", os.path.join("configs", config_path) + ) + if not os.path.exists(cfg_file): + raise RuntimeError(f"Config {config_path} not available!") + return cfg_file + + +def build_model(model_name, config_path=None, ckpt_path=None, cache_dir=None): + if (config_path is not None) and (ckpt_path is not None): + config = OmegaConf.load(config_path) + model = instantiate_from_config(config.model) + model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False) + return model + + if not model_name in PRETRAINED_MODELS: + raise RuntimeError( + f"Model name 
{model_name} is not a pre-trained model. Available models are:\n- " + + "\n- ".join(PRETRAINED_MODELS.keys()) + ) + model_info = PRETRAINED_MODELS[model_name] + + # Instiantiate the model + print(f"Loading model from config: {model_info['config']}") + config_file = get_config_file(model_info["config"]) + config = OmegaConf.load(config_file) + model = instantiate_from_config(config.model) + + # Load pre-trained checkpoint from huggingface + if not ckpt_path: + ckpt_path = hf_hub_download( + repo_id=model_info["repo_id"], + filename=model_info["filename"], + cache_dir=cache_dir, + ) + print(f"Loading model from cache file: {ckpt_path}") + model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False) + return model diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a1be926e2fb3381acf1dbf6772b0d393451d70 --- /dev/null +++ b/inference.py @@ -0,0 +1,91 @@ +import numpy as np +import torch +import time +import nvdiffrast.torch as dr +from util.utils import get_tri +import tempfile +from mesh import Mesh +import zipfile +def generate3d(model, rgb, ccm, device): + + color_tri = torch.from_numpy(rgb)/255 + xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)])/255 + color = color_tri.permute(2,0,1) + xyz = xyz_tri.permute(2,0,1) + + + def get_imgs(color): + # color : [C, H, W*6] + color_list = [] + color_list.append(color[:,:,256*5:256*(1+5)]) + for i in range(0,5): + color_list.append(color[:,:,256*i:256*(1+i)]) + return torch.stack(color_list, dim=0)# [6, C, H, W] + + triplane_color = get_imgs(color).permute(0,2,3,1).unsqueeze(0).to(device)# [1, 6, H, W, C] + + color = get_imgs(color) + xyz = get_imgs(xyz) + + color = get_tri(color, dim=0, blender= True, scale = 1).unsqueeze(0) + xyz = get_tri(xyz, dim=0, blender= True, scale = 1, fix= True).unsqueeze(0) + + triplane = torch.cat([color,xyz],dim=1).to(device) + # 3D visualize + model.eval() + glctx = dr.RasterizeCudaContext() + + if model.denoising == True: + tnew = 20 + tnew = torch.randint(tnew, tnew+1, [triplane.shape[0]], dtype=torch.long, device=triplane.device) + noise_new = torch.randn_like(triplane) *0.5+0.5 + triplane = model.scheduler.add_noise(triplane, noise_new, tnew) + start_time = time.time() + with torch.no_grad(): + triplane_feature2 = model.unet2(triplane,tnew) + end_time = time.time() + elapsed_time = end_time - start_time + print(f"unet takes {elapsed_time}s") + else: + triplane_feature2 = model.unet2(triplane) + + + with torch.no_grad(): + data_config = { + 'resolution': [1024, 1024], + "triview_color": triplane_color.to(device), + } + + verts, faces = model.decode(data_config, triplane_feature2) + + data_config['verts'] = verts[0] + data_config['faces'] = faces + + + from kiui.mesh_utils import clean_mesh + verts, faces = clean_mesh(data_config['verts'].squeeze().cpu().numpy().astype(np.float32), data_config['faces'].squeeze().cpu().numpy().astype(np.int32), repair = False, remesh=False, remesh_size=0.005) + data_config['verts'] = torch.from_numpy(verts).cuda().contiguous() + data_config['faces'] = torch.from_numpy(faces).cuda().contiguous() + + start_time = time.time() + with torch.no_grad(): + mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name + model.export_mesh_wt_uv(glctx, data_config, mesh_path_obj, "", device, res=(1024,1024), tri_fea_2=triplane_feature2) + + mesh = Mesh.load(mesh_path_obj+".obj", bound=0.9, front_dir="+z") + mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name + 
mesh.write(mesh_path_glb+".glb") + + # mesh_obj2 = trimesh.load(mesh_path_glb+".glb", file_type='glb') + # mesh_path_obj2 = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name + # mesh_obj2.export(mesh_path_obj2+".obj") + + with zipfile.ZipFile(mesh_path_obj+'.zip', 'w') as myzip: + myzip.write(mesh_path_obj+'.obj', mesh_path_obj.split("/")[-1]+'.obj') + myzip.write(mesh_path_obj+'.png', mesh_path_obj.split("/")[-1]+'.png') + myzip.write(mesh_path_obj+'.mtl', mesh_path_obj.split("/")[-1]+'.mtl') + + end_time = time.time() + elapsed_time = end_time - start_time + print(f"uv takes {elapsed_time}s") + return mesh_path_glb+".glb", mesh_path_obj+'.zip' diff --git a/libs/base_utils.py b/libs/base_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9233461bd699b9df025a7a532ea72a7c023b30f0 --- /dev/null +++ b/libs/base_utils.py @@ -0,0 +1,84 @@ +import numpy as np +import cv2 +import torch +import numpy as np +from PIL import Image + + +def instantiate_from_config(config): + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + import importlib + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def tensor_detail(t): + assert type(t) == torch.Tensor + print(f"shape: {t.shape} mean: {t.mean():.2f}, std: {t.std():.2f}, min: {t.min():.2f}, max: {t.max():.2f}") + + + +def drawRoundRec(draw, color, x, y, w, h, r): + drawObject = draw + + '''Rounds''' + drawObject.ellipse((x, y, x + r, y + r), fill=color) + drawObject.ellipse((x + w - r, y, x + w, y + r), fill=color) + drawObject.ellipse((x, y + h - r, x + r, y + h), fill=color) + drawObject.ellipse((x + w - r, y + h - r, x + w, y + h), fill=color) + + '''rec.s''' + drawObject.rectangle((x + r / 2, y, x + w - (r / 2), y + h), fill=color) + drawObject.rectangle((x, y + r / 2, x + w, y + h - (r / 2)), fill=color) + + +def do_resize_content(original_image: Image, scale_rate): + # resize image content wile retain the original image size + if scale_rate != 1: + # Calculate the new size after rescaling + new_size = tuple(int(dim * scale_rate) for dim in original_image.size) + # Resize the image while maintaining the aspect ratio + resized_image = original_image.resize(new_size) + # Create a new image with the original size and black background + padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0)) + paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2) + padded_image.paste(resized_image, paste_position) + return padded_image + else: + return original_image + +def add_stroke(img, color=(255, 255, 255), stroke_radius=3): + # color in R, G, B format + if isinstance(img, Image.Image): + assert img.mode == "RGBA" + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2BGRA) + else: + assert img.shape[2] == 4 + gray = img[:,:, 3] + ret, binary = cv2.threshold(gray,127,255,cv2.THRESH_BINARY) + contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE) + res = cv2.drawContours(img, contours,-1, tuple(color)[::-1] + (255,), stroke_radius) + return Image.fromarray(cv2.cvtColor(res,cv2.COLOR_BGRA2RGBA)) + +def make_blob(image_size=(512, 512), sigma=0.2): + """ + make 2D blob image with: + I(x, y)=1-\exp \left(-\frac{(x-H / 2)^2+(y-W 
/ 2)^2}{2 \sigma^2 HS}\right) + """ + import numpy as np + H, W = image_size + x = np.arange(0, W, 1, float) + y = np.arange(0, H, 1, float) + x, y = np.meshgrid(x, y) + x0 = W // 2 + y0 = H // 2 + img = 1 - np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2 * H * W)) + return (img * 255).astype(np.uint8) \ No newline at end of file diff --git a/libs/sample.py b/libs/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..44ed9a20bd52a626dcc50199c1636a0c87ae85f5 --- /dev/null +++ b/libs/sample.py @@ -0,0 +1,380 @@ +import numpy as np +import torch +from imagedream.camera_utils import get_camera_for_index +from imagedream.ldm.util import set_seed, add_random_background +from libs.base_utils import do_resize_content +from imagedream.ldm.models.diffusion.ddim import DDIMSampler +from torchvision import transforms as T + + +class ImageDreamDiffusion: + def __init__( + self, + model, + device, + dtype, + mode, + num_frames, + camera_views, + ref_position, + random_background=False, + offset_noise=False, + resize_rate=1, + image_size=256, + seed=1234, + ) -> None: + assert mode in ["pixel", "local"] + size = image_size + self.seed = seed + batch_size = max(4, num_frames) + + neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear." + uc = model.get_learned_conditioning([neg_texts]).to(device) + sampler = DDIMSampler(model) + + # pre-compute camera matrices + camera = [get_camera_for_index(i).squeeze() for i in camera_views] + camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero + camera = torch.stack(camera) + camera = camera.repeat(batch_size // num_frames, 1).to(device) + + self.image_transform = T.Compose( + [ + T.Resize((size, size)), + T.ToTensor(), + T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + self.dtype = dtype + self.ref_position = ref_position + self.mode = mode + self.random_background = random_background + self.resize_rate = resize_rate + self.num_frames = num_frames + self.size = size + self.device = device + self.batch_size = batch_size + self.model = model + self.sampler = sampler + self.uc = uc + self.camera = camera + self.offset_noise = offset_noise + + @staticmethod + def i2i( + model, + image_size, + prompt, + uc, + sampler, + ip=None, + step=20, + scale=5.0, + batch_size=8, + ddim_eta=0.0, + dtype=torch.float32, + device="cuda", + camera=None, + num_frames=4, + pixel_control=False, + transform=None, + offset_noise=False, + ): + """ The function supports additional image prompt. + Args: + model (_type_): the image dream model + image_size (_type_): size of diffusion output (standard 256) + prompt (_type_): text prompt for the image (prompt in type str) + uc (_type_): unconditional vector (tensor in shape [1, 77, 1024]) + sampler (_type_): imagedream.ldm.models.diffusion.ddim.DDIMSampler + ip (Image, optional): the image prompt. Defaults to None. + step (int, optional): _description_. Defaults to 20. + scale (float, optional): _description_. Defaults to 7.5. + batch_size (int, optional): _description_. Defaults to 8. + ddim_eta (float, optional): _description_. Defaults to 0.0. + dtype (_type_, optional): _description_. Defaults to torch.float32. + device (str, optional): _description_. Defaults to "cuda". 
+ camera (_type_, optional): camera info in tensor, shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 + num_frames (int, optional): _num of frames (views) to generate + pixel_control: whether to use pixel conditioning. Defaults to False, True when using pixel mode + transform: Compose( + Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=warn) + ToTensor() + Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ) + """ + ip_raw = ip + if type(prompt) != list: + prompt = [prompt] + with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype): + c = model.get_learned_conditioning(prompt).to( + device + ) # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05 + c_ = {"context": c.repeat(batch_size, 1, 1)} # batch_size + uc_ = {"context": uc.repeat(batch_size, 1, 1)} + + if camera is not None: + c_["camera"] = uc_["camera"] = ( + camera # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 + ) + c_["num_frames"] = uc_["num_frames"] = num_frames + + if ip is not None: + ip_embed = model.get_learned_image_conditioning(ip).to( + device + ) # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12 + ip_ = ip_embed.repeat(batch_size, 1, 1) + c_["ip"] = ip_ + uc_["ip"] = torch.zeros_like(ip_) + + if pixel_control: + assert camera is not None + ip = transform(ip).to( + device + ) # shape: torch.Size([3, 256, 256]) mean: 0.33, std: 0.37, min: -1.00, max: 1.00 + ip_img = model.get_first_stage_encoding( + model.encode_first_stage(ip[None, :, :, :]) + ) # shape: torch.Size([1, 4, 32, 32]) mean: 0.23, std: 0.77, min: -4.42, max: 3.55 + c_["ip_img"] = ip_img + uc_["ip_img"] = torch.zeros_like(ip_img) + + shape = [4, image_size // 8, image_size // 8] # [4, 32, 32] + if offset_noise: + ref = transform(ip_raw).to(device) + ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :])) + ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True) + time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device) + x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps) + + samples_ddim, _ = ( + sampler.sample( # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43 + S=step, + conditioning=c_, + batch_size=batch_size, + shape=shape, + verbose=False, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc_, + eta=ddim_eta, + x_T=x_T if offset_noise else None, + ) + ) + + x_sample = model.decode_first_stage(samples_ddim) + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy() + + return list(x_sample.astype(np.uint8)) + + def diffuse(self, t, ip, n_test=2): + set_seed(self.seed) + ip = do_resize_content(ip, self.resize_rate) + if self.random_background: + ip = add_random_background(ip) + + images = [] + for _ in range(n_test): + img = self.i2i( + self.model, + self.size, + t, + self.uc, + self.sampler, + ip=ip, + step=50, + scale=5, + batch_size=self.batch_size, + ddim_eta=0.0, + dtype=self.dtype, + device=self.device, + camera=self.camera, + num_frames=self.num_frames, + pixel_control=(self.mode == "pixel"), + transform=self.image_transform, + offset_noise=self.offset_noise, + ) + img = np.concatenate(img, 1) + img = np.concatenate((img, ip.resize((self.size, self.size))), axis=1) + images.append(img) + set_seed() # unset random and numpy seed + return images + + 
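+# NOTE (editor's sketch, not part of the original patch): a minimal, hypothetical
+# example of how the stage-1 wrapper above could be driven. It assumes an
+# ImageDream checkpoint loaded via imagedream.model_zoo.build_model and a PIL
+# image prompt; the text prompt, camera layout and dtype below are illustrative only.
+def _example_stage1_run(model, input_image, device="cuda"):
+    stage1 = ImageDreamDiffusion(
+        model,
+        device=device,
+        dtype=torch.float16,
+        mode="pixel",
+        num_frames=4,
+        camera_views=[0, 1, 2, 3],
+        ref_position=0,
+    )
+    # diffuse() returns a list of n_test uint8 numpy grids, each concatenating
+    # the generated views with the (resized) input image prompt.
+    return stage1.diffuse("a 3D asset", input_image, n_test=1)
+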
+class ImageDreamDiffusionStage2: + def __init__( + self, + model, + device, + dtype, + num_frames, + camera_views, + ref_position, + random_background=False, + offset_noise=False, + resize_rate=1, + mode="pixel", + image_size=256, + seed=1234, + ) -> None: + assert mode in ["pixel", "local"] + + size = image_size + self.seed = seed + batch_size = max(4, num_frames) + + neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear." + uc = model.get_learned_conditioning([neg_texts]).to(device) + sampler = DDIMSampler(model) + + # pre-compute camera matrices + camera = [get_camera_for_index(i).squeeze() for i in camera_views] + if ref_position is not None: + camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero + camera = torch.stack(camera) + camera = camera.repeat(batch_size // num_frames, 1).to(device) + + self.image_transform = T.Compose( + [ + T.Resize((size, size)), + T.ToTensor(), + T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + + self.dtype = dtype + self.mode = mode + self.ref_position = ref_position + self.random_background = random_background + self.resize_rate = resize_rate + self.num_frames = num_frames + self.size = size + self.device = device + self.batch_size = batch_size + self.model = model + self.sampler = sampler + self.uc = uc + self.camera = camera + self.offset_noise = offset_noise + + @staticmethod + def i2iStage2( + model, + image_size, + prompt, + uc, + sampler, + pixel_images, + ip=None, + step=20, + scale=5.0, + batch_size=8, + ddim_eta=0.0, + dtype=torch.float32, + device="cuda", + camera=None, + num_frames=4, + pixel_control=False, + transform=None, + offset_noise=False, + ): + ip_raw = ip + if type(prompt) != list: + prompt = [prompt] + with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype): + c = model.get_learned_conditioning(prompt).to( + device + ) # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05 + c_ = {"context": c.repeat(batch_size, 1, 1)} # batch_size + uc_ = {"context": uc.repeat(batch_size, 1, 1)} + + if camera is not None: + c_["camera"] = uc_["camera"] = ( + camera # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 + ) + c_["num_frames"] = uc_["num_frames"] = num_frames + + if ip is not None: + ip_embed = model.get_learned_image_conditioning(ip).to( + device + ) # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12 + ip_ = ip_embed.repeat(batch_size, 1, 1) + c_["ip"] = ip_ + uc_["ip"] = torch.zeros_like(ip_) + + if pixel_control: + assert camera is not None + + transed_pixel_images = torch.stack([transform(i).to(device) for i in pixel_images]) + latent_pixel_images = model.get_first_stage_encoding(model.encode_first_stage(transed_pixel_images)) + + c_["pixel_images"] = latent_pixel_images + uc_["pixel_images"] = torch.zeros_like(latent_pixel_images) + + shape = [4, image_size // 8, image_size // 8] # [4, 32, 32] + if offset_noise: + ref = transform(ip_raw).to(device) + ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :])) + ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True) + time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device) + x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps) + + samples_ddim, _ = ( + sampler.sample( # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, 
min: -3.38, max: 4.43 + S=step, + conditioning=c_, + batch_size=batch_size, + shape=shape, + verbose=False, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc_, + eta=ddim_eta, + x_T=x_T if offset_noise else None, + ) + ) + x_sample = model.decode_first_stage(samples_ddim) + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy() + + return list(x_sample.astype(np.uint8)) + + @torch.no_grad() + def diffuse(self, t, ip, pixel_images, n_test=2): + set_seed(self.seed) + ip = do_resize_content(ip, self.resize_rate) + pixel_images = [do_resize_content(i, self.resize_rate) for i in pixel_images] + + if self.random_background: + bg_color = np.random.rand() * 255 + ip = add_random_background(ip, bg_color) + pixel_images = [add_random_background(i, bg_color) for i in pixel_images] + + images = [] + for _ in range(n_test): + img = self.i2iStage2( + self.model, + self.size, + t, + self.uc, + self.sampler, + pixel_images=pixel_images, + ip=ip, + step=50, + scale=5, + batch_size=self.batch_size, + ddim_eta=0.0, + dtype=self.dtype, + device=self.device, + camera=self.camera, + num_frames=self.num_frames, + pixel_control=(self.mode == "pixel"), + transform=self.image_transform, + offset_noise=self.offset_noise, + ) + img = np.concatenate(img, 1) + img = np.concatenate( + (img, ip.resize((self.size, self.size)), *[i.resize((self.size, self.size)) for i in pixel_images]), + axis=1, + ) + images.append(img) + set_seed() # unset random and numpy seed + return images diff --git a/mesh.py b/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..225a271023da1f5fc3e6a709ce465a6fa91afb5d --- /dev/null +++ b/mesh.py @@ -0,0 +1,845 @@ +import os +import cv2 +import torch +import trimesh +import numpy as np + +from kiui.op import safe_normalize, dot +from kiui.typing import * + +class Mesh: + """ + A torch-native trimesh class, with support for ``ply/obj/glb`` formats. + + Note: + This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture). + """ + def __init__( + self, + v: Optional[Tensor] = None, + f: Optional[Tensor] = None, + vn: Optional[Tensor] = None, + fn: Optional[Tensor] = None, + vt: Optional[Tensor] = None, + ft: Optional[Tensor] = None, + vc: Optional[Tensor] = None, # vertex color + albedo: Optional[Tensor] = None, + metallicRoughness: Optional[Tensor] = None, + device: Optional[torch.device] = None, + ): + """Init a mesh directly using all attributes. + + Args: + v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None. + f (Optional[Tensor]): faces, int [M, 3]. Defaults to None. + vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None. + fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None. + vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None. + ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None. + vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None. + albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None. + metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None. + device (Optional[torch.device]): torch device. Defaults to None. 
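+
+        Example (editor's illustration, not part of the original docstring):
+            ``mesh = Mesh(v=verts, f=faces, albedo=albedo, device=torch.device("cuda"))``
+            with ``verts`` float [N, 3], ``faces`` int [M, 3] and ``albedo`` float [H, W, 3].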
+ """ + self.device = device + self.v = v + self.vn = vn + self.vt = vt + self.f = f + self.fn = fn + self.ft = ft + # will first see if there is vertex color to use + self.vc = vc + # only support a single albedo image + self.albedo = albedo + # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1] + # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html + self.metallicRoughness = metallicRoughness + + self.ori_center = 0 + self.ori_scale = 1 + + @classmethod + def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs): + """load mesh from path. + + Args: + path (str): path to mesh file, supports ply, obj, glb. + clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False. + resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True. + renormal (bool, optional): re-calc the vertex normals. Defaults to True. + retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False. + bound (float, optional): bound to resize. Defaults to 0.9. + front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'. + device (torch.device, optional): torch device. Defaults to None. + + Note: + a ``device`` keyword argument can be provided to specify the torch device. + If it's not provided, we will try to use ``'cuda'`` as the device if it's available. + + Returns: + Mesh: the loaded Mesh object. + """ + # obj supports face uv + if path.endswith(".obj"): + mesh = cls.load_obj(path, **kwargs) + # trimesh only supports vertex uv, but can load more formats + else: + mesh = cls.load_trimesh(path, **kwargs) + + # clean + if clean: + from kiui.mesh_utils import clean_mesh + vertices = mesh.v.detach().cpu().numpy() + triangles = mesh.f.detach().cpu().numpy() + vertices, triangles = clean_mesh(vertices, triangles, remesh=False) + mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device) + mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device) + + print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}") + # auto-normalize + if resize: + mesh.auto_size(bound=bound) + # auto-fix normal + if renormal or mesh.vn is None: + mesh.auto_normal() + print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}") + # auto-fix texcoords + if retex or (mesh.albedo is not None and mesh.vt is None): + mesh.auto_uv(cache_path=path) + print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}") + + # rotate front dir to +z + if front_dir != "+z": + # axis switch + if "-z" in front_dir: + T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32) + elif "+x" in front_dir: + T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32) + elif "-x" in front_dir: + T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32) + elif "+y" in front_dir: + T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32) + elif "-y" in front_dir: + T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32) + else: + T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + # rotation (how many 90 degrees) + if '1' in front_dir: + T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, 
dtype=torch.float32) + elif '2' in front_dir: + T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + elif '3' in front_dir: + T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + mesh.v @= T + mesh.vn @= T + + return mesh + + # load from obj file + @classmethod + def load_obj(cls, path, albedo_path=None, device=None): + """load an ``obj`` mesh. + + Args: + path (str): path to mesh. + albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None. + device (torch.device, optional): torch device. Defaults to None. + + Note: + We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension. + The `usemtl` statement is ignored, and we only use the last material path in `mtl` file. + + Returns: + Mesh: the loaded Mesh object. + """ + assert os.path.splitext(path)[-1] == ".obj" + + mesh = cls() + + # device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + mesh.device = device + + # load obj + with open(path, "r") as f: + lines = f.readlines() + + def parse_f_v(fv): + # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided) + # supported forms: + # f v1 v2 v3 + # f v1/vt1 v2/vt2 v3/vt3 + # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3 + # f v1//vn1 v2//vn2 v3//vn3 + xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")] + xs.extend([-1] * (3 - len(xs))) + return xs[0], xs[1], xs[2] + + vertices, texcoords, normals = [], [], [] + faces, tfaces, nfaces = [], [], [] + mtl_path = None + + for line in lines: + split_line = line.split() + # empty line + if len(split_line) == 0: + continue + prefix = split_line[0].lower() + # mtllib + if prefix == "mtllib": + mtl_path = split_line[1] + # usemtl + elif prefix == "usemtl": + pass # ignored + # v/vn/vt + elif prefix == "v": + vertices.append([float(v) for v in split_line[1:]]) + elif prefix == "vn": + normals.append([float(v) for v in split_line[1:]]) + elif prefix == "vt": + val = [float(v) for v in split_line[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + elif prefix == "f": + vs = split_line[1:] + nv = len(vs) + v0, t0, n0 = parse_f_v(vs[0]) + for i in range(nv - 2): # triangulate (assume vertices are ordered) + v1, t1, n1 = parse_f_v(vs[i + 1]) + v2, t2, n2 = parse_f_v(vs[i + 2]) + faces.append([v0, v1, v2]) + tfaces.append([t0, t1, t2]) + nfaces.append([n0, n1, n2]) + + mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) + mesh.vt = ( + torch.tensor(texcoords, dtype=torch.float32, device=device) + if len(texcoords) > 0 + else None + ) + mesh.vn = ( + torch.tensor(normals, dtype=torch.float32, device=device) + if len(normals) > 0 + else None + ) + + mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) + mesh.ft = ( + torch.tensor(tfaces, dtype=torch.int32, device=device) + if len(texcoords) > 0 + else None + ) + mesh.fn = ( + torch.tensor(nfaces, dtype=torch.int32, device=device) + if len(normals) > 0 + else None + ) + + # see if there is vertex color + use_vertex_color = False + if mesh.v.shape[1] == 6: + use_vertex_color = True + mesh.vc = mesh.v[:, 3:] + mesh.v = mesh.v[:, :3] + print(f"[load_obj] use vertex color: {mesh.vc.shape}") + + # try to load texture image + if not use_vertex_color: + # try to retrieve mtl file + mtl_path_candidates = [] + if mtl_path is not None: + mtl_path_candidates.append(mtl_path) + 
mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path)) + mtl_path_candidates.append(path.replace(".obj", ".mtl")) + + mtl_path = None + for candidate in mtl_path_candidates: + if os.path.exists(candidate): + mtl_path = candidate + break + + # if albedo_path is not provided, try retrieve it from mtl + metallic_path = None + roughness_path = None + if mtl_path is not None and albedo_path is None: + with open(mtl_path, "r") as f: + lines = f.readlines() + + for line in lines: + split_line = line.split() + # empty line + if len(split_line) == 0: + continue + prefix = split_line[0] + + if "map_Kd" in prefix: + # assume relative path! + albedo_path = os.path.join(os.path.dirname(path), split_line[1]) + print(f"[load_obj] use texture from: {albedo_path}") + elif "map_Pm" in prefix: + metallic_path = os.path.join(os.path.dirname(path), split_line[1]) + elif "map_Pr" in prefix: + roughness_path = os.path.join(os.path.dirname(path), split_line[1]) + + # still not found albedo_path, or the path doesn't exist + if albedo_path is None or not os.path.exists(albedo_path): + # init an empty texture + print(f"[load_obj] init empty albedo!") + # albedo = np.random.rand(1024, 1024, 3).astype(np.float32) + albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color + else: + albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED) + albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB) + albedo = albedo.astype(np.float32) / 255 + print(f"[load_obj] load texture: {albedo.shape}") + + mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device) + + # try to load metallic and roughness + if metallic_path is not None and roughness_path is not None: + print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}") + metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED) + metallic = metallic.astype(np.float32) / 255 + roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED) + roughness = roughness.astype(np.float32) / 255 + metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1) + + mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous() + + return mesh + + @classmethod + def load_trimesh(cls, path, device=None): + """load a mesh using ``trimesh.load()``. + + Can load various formats like ``glb`` and serves as a fallback. + + Note: + We will try to merge all meshes if the glb contains more than one, + but **this may cause the texture to lose**, since we only support one texture image! + + Args: + path (str): path to the mesh file. + device (torch.device, optional): torch device. Defaults to None. + + Returns: + Mesh: the loaded Mesh object. 
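+
+        Example (editor's illustration, not part of the original docstring):
+            ``mesh = Mesh.load("model.glb")  # non-obj formats are routed to load_trimesh``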
+ """ + mesh = cls() + + # device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + mesh.device = device + + # use trimesh to load ply/glb + _data = trimesh.load(path) + if isinstance(_data, trimesh.Scene): + if len(_data.geometry) == 1: + _mesh = list(_data.geometry.values())[0] + else: + print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.") + _concat = [] + # loop the scene graph and apply transform to each mesh + scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}} + for k, v in scene_graph.items(): + name = v['geometry'] + if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh): + transform = v['transform'] + _concat.append(_data.geometry[name].apply_transform(transform)) + _mesh = trimesh.util.concatenate(_concat) + else: + _mesh = _data + + if _mesh.visual.kind == 'vertex': + vertex_colors = _mesh.visual.vertex_colors + vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255 + mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device) + print(f"[load_trimesh] use vertex color: {mesh.vc.shape}") + elif _mesh.visual.kind == 'texture': + _material = _mesh.visual.material + if isinstance(_material, trimesh.visual.material.PBRMaterial): + texture = np.array(_material.baseColorTexture).astype(np.float32) / 255 + # load metallicRoughness if present + if _material.metallicRoughnessTexture is not None: + metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255 + mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous() + elif isinstance(_material, trimesh.visual.material.SimpleMaterial): + texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255 + else: + raise NotImplementedError(f"material type {type(_material)} not supported!") + mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous() + print(f"[load_trimesh] load texture: {texture.shape}") + else: + texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) + mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device) + print(f"[load_trimesh] failed to load texture.") + + vertices = _mesh.vertices + + try: + texcoords = _mesh.visual.uv + texcoords[:, 1] = 1 - texcoords[:, 1] + except Exception as e: + texcoords = None + + try: + normals = _mesh.vertex_normals + except Exception as e: + normals = None + + # trimesh only support vertex uv... + faces = tfaces = nfaces = _mesh.faces + + mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) + mesh.vt = ( + torch.tensor(texcoords, dtype=torch.float32, device=device) + if texcoords is not None + else None + ) + mesh.vn = ( + torch.tensor(normals, dtype=torch.float32, device=device) + if normals is not None + else None + ) + + mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) + mesh.ft = ( + torch.tensor(tfaces, dtype=torch.int32, device=device) + if texcoords is not None + else None + ) + mesh.fn = ( + torch.tensor(nfaces, dtype=torch.int32, device=device) + if normals is not None + else None + ) + + return mesh + + # sample surface (using trimesh) + def sample_surface(self, count: int): + """sample points on the surface of the mesh. + + Args: + count (int): number of points to sample. + + Returns: + torch.Tensor: the sampled points, float [count, 3]. 
+ """ + _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy()) + points, face_idx = trimesh.sample.sample_surface(_mesh, count) + points = torch.from_numpy(points).float().to(self.device) + return points + + # aabb + def aabb(self): + """get the axis-aligned bounding box of the mesh. + + Returns: + Tuple[torch.Tensor]: the min xyz and max xyz of the mesh. + """ + return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values + + # unit size + @torch.no_grad() + def auto_size(self, bound=0.9): + """auto resize the mesh. + + Args: + bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9. + """ + vmin, vmax = self.aabb() + self.ori_center = (vmax + vmin) / 2 + self.ori_scale = 2 * bound / torch.max(vmax - vmin).item() + self.v = (self.v - self.ori_center) * self.ori_scale + + def auto_normal(self): + """auto calculate the vertex normals. + """ + i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long() + v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + vn = torch.zeros_like(self.v) + vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + vn = torch.where( + dot(vn, vn) > 1e-20, + vn, + torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device), + ) + vn = safe_normalize(vn) + + self.vn = vn + self.fn = self.f + + def auto_uv(self, cache_path=None, vmap=True): + """auto calculate the uv coordinates. + + Args: + cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None. + vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf). + Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True. + """ + # try to load cache + if cache_path is not None: + cache_path = os.path.splitext(cache_path)[0] + "_uv.npz" + if cache_path is not None and os.path.exists(cache_path): + data = np.load(cache_path) + vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"] + else: + import xatlas + + v_np = self.v.detach().cpu().numpy() + f_np = self.f.detach().int().cpu().numpy() + atlas = xatlas.Atlas() + atlas.add_mesh(v_np, f_np) + chart_options = xatlas.ChartOptions() + # chart_options.max_iterations = 4 + atlas.generate(chart_options=chart_options) + vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] + + # save to cache + if cache_path is not None: + np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping) + + vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device) + ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device) + self.vt = vt + self.ft = ft + + if vmap: + vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device) + self.align_v_to_vt(vmapping) + + def align_v_to_vt(self, vmapping=None): + """ remap v/f and vn/fn to vt/ft. + + Args: + vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None. 
+ """ + if vmapping is None: + ft = self.ft.view(-1).long() + f = self.f.view(-1).long() + vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device) + vmapping[ft] = f # scatter, randomly choose one if index is not unique + + self.v = self.v[vmapping] + self.f = self.ft + + if self.vn is not None: + self.vn = self.vn[vmapping] + self.fn = self.ft + + def to(self, device): + """move all tensor attributes to device. + + Args: + device (torch.device): target device. + + Returns: + Mesh: self. + """ + self.device = device + for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]: + tensor = getattr(self, name) + if tensor is not None: + setattr(self, name, tensor.to(device)) + return self + + def write(self, path): + """write the mesh to a path. + + Args: + path (str): path to write, supports ply, obj and glb. + """ + if path.endswith(".ply"): + self.write_ply(path) + elif path.endswith(".obj"): + self.write_obj(path) + elif path.endswith(".glb") or path.endswith(".gltf"): + self.write_glb(path) + else: + raise NotImplementedError(f"format {path} not supported!") + + def write_ply(self, path): + """write the mesh in ply format. Only for geometry! + + Args: + path (str): path to write. + """ + + if self.albedo is not None: + print(f'[WARN] ply format does not support exporting texture, will ignore!') + + v_np = self.v.detach().cpu().numpy() + f_np = self.f.detach().cpu().numpy() + + _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np) + _mesh.export(path) + + + def write_glb(self, path): + """write the mesh in glb/gltf format. + This will create a scene with a single mesh. + + Args: + path (str): path to write. + """ + + # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0] + if self.vt is not None and self.v.shape[0] != self.vt.shape[0]: + self.align_v_to_vt() + + import pygltflib + + f_np = self.f.detach().cpu().numpy().astype(np.uint32) + f_np_blob = f_np.flatten().tobytes() + + v_np = self.v.detach().cpu().numpy().astype(np.float32) + v_np_blob = v_np.tobytes() + + blob = f_np_blob + v_np_blob + byteOffset = len(blob) + + # base mesh + gltf = pygltflib.GLTF2( + scene=0, + scenes=[pygltflib.Scene(nodes=[0])], + nodes=[pygltflib.Node(mesh=0)], + meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive( + # indices to accessors (0 is triangles) + attributes=pygltflib.Attributes( + POSITION=1, + ), + indices=0, + )])], + buffers=[ + pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob)) + ], + # buffer view (based on dtype) + bufferViews=[ + # triangles; as flatten (element) array + pygltflib.BufferView( + buffer=0, + byteLength=len(f_np_blob), + target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963) + ), + # positions; as vec3 array + pygltflib.BufferView( + buffer=0, + byteOffset=len(f_np_blob), + byteLength=len(v_np_blob), + byteStride=12, # vec3 + target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962) + ), + ], + accessors=[ + # 0 = triangles + pygltflib.Accessor( + bufferView=0, + componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125) + count=f_np.size, + type=pygltflib.SCALAR, + max=[int(f_np.max())], + min=[int(f_np.min())], + ), + # 1 = positions + pygltflib.Accessor( + bufferView=1, + componentType=pygltflib.FLOAT, # GL_FLOAT (5126) + count=len(v_np), + type=pygltflib.VEC3, + max=v_np.max(axis=0).tolist(), + min=v_np.min(axis=0).tolist(), + ), + ], + ) + + # append texture info + if self.vt is not None: + + vt_np = self.vt.detach().cpu().numpy().astype(np.float32) + vt_np_blob 
= vt_np.tobytes() + + albedo = self.albedo.detach().cpu().numpy() + albedo = (albedo * 255).astype(np.uint8) + albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR) + albedo_blob = cv2.imencode('.png', albedo)[1].tobytes() + + # update primitive + gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2 + gltf.meshes[0].primitives[0].material = 0 + + # update materials + gltf.materials.append(pygltflib.Material( + pbrMetallicRoughness=pygltflib.PbrMetallicRoughness( + baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0), + metallicFactor=0.0, + roughnessFactor=1.0, + ), + alphaMode=pygltflib.OPAQUE, + alphaCutoff=None, + doubleSided=True, + )) + + gltf.textures.append(pygltflib.Texture(sampler=0, source=0)) + gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT)) + gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png")) + + # update buffers + gltf.bufferViews.append( + # index = 2, texcoords; as vec2 array + pygltflib.BufferView( + buffer=0, + byteOffset=byteOffset, + byteLength=len(vt_np_blob), + byteStride=8, # vec2 + target=pygltflib.ARRAY_BUFFER, + ) + ) + + gltf.accessors.append( + # 2 = texcoords + pygltflib.Accessor( + bufferView=2, + componentType=pygltflib.FLOAT, + count=len(vt_np), + type=pygltflib.VEC2, + max=vt_np.max(axis=0).tolist(), + min=vt_np.min(axis=0).tolist(), + ) + ) + + blob += vt_np_blob + byteOffset += len(vt_np_blob) + + gltf.bufferViews.append( + # index = 3, albedo texture; as none target + pygltflib.BufferView( + buffer=0, + byteOffset=byteOffset, + byteLength=len(albedo_blob), + ) + ) + + blob += albedo_blob + byteOffset += len(albedo_blob) + + gltf.buffers[0].byteLength = byteOffset + + # append metllic roughness + if self.metallicRoughness is not None: + metallicRoughness = self.metallicRoughness.detach().cpu().numpy() + metallicRoughness = (metallicRoughness * 255).astype(np.uint8) + metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR) + metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes() + + # update texture definition + gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0 + gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0 + gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0) + + gltf.textures.append(pygltflib.Texture(sampler=1, source=1)) + gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT)) + gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png")) + + # update buffers + gltf.bufferViews.append( + # index = 4, metallicRoughness texture; as none target + pygltflib.BufferView( + buffer=0, + byteOffset=byteOffset, + byteLength=len(metallicRoughness_blob), + ) + ) + + blob += metallicRoughness_blob + byteOffset += len(metallicRoughness_blob) + + gltf.buffers[0].byteLength = byteOffset + + + # set actual data + gltf.set_binary_blob(blob) + + # glb = b"".join(gltf.save_to_bytes()) + gltf.save(path) + + + def write_obj(self, path): + """write the mesh in obj format. Will also write the texture and mtl files. + + Args: + path (str): path to write. 
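+
+        Example (editor's illustration, not part of the original docstring):
+            ``mesh.write_obj("out/mesh.obj")`` also writes ``out/mesh.mtl`` and, when the
+            corresponding textures exist, ``out/mesh_albedo.png``, ``out/mesh_metallic.png``
+            and ``out/mesh_roughness.png``.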
+ """ + + mtl_path = path.replace(".obj", ".mtl") + albedo_path = path.replace(".obj", "_albedo.png") + metallic_path = path.replace(".obj", "_metallic.png") + roughness_path = path.replace(".obj", "_roughness.png") + + v_np = self.v.detach().cpu().numpy() + vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None + vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None + f_np = self.f.detach().cpu().numpy() + ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None + fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None + + with open(path, "w") as fp: + fp.write(f"mtllib {os.path.basename(mtl_path)} \n") + + for v in v_np: + fp.write(f"v {v[0]} {v[1]} {v[2]} \n") + + if vt_np is not None: + for v in vt_np: + fp.write(f"vt {v[0]} {1 - v[1]} \n") + + if vn_np is not None: + for v in vn_np: + fp.write(f"vn {v[0]} {v[1]} {v[2]} \n") + + fp.write(f"usemtl defaultMat \n") + for i in range(len(f_np)): + fp.write( + f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \ + {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \ + {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n' + ) + + with open(mtl_path, "w") as fp: + fp.write(f"newmtl defaultMat \n") + fp.write(f"Ka 1 1 1 \n") + fp.write(f"Kd 1 1 1 \n") + fp.write(f"Ks 0 0 0 \n") + fp.write(f"Tr 1 \n") + fp.write(f"illum 1 \n") + fp.write(f"Ns 0 \n") + if self.albedo is not None: + fp.write(f"map_Kd {os.path.basename(albedo_path)} \n") + if self.metallicRoughness is not None: + # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering + fp.write(f"map_Pm {os.path.basename(metallic_path)} \n") + fp.write(f"map_Pr {os.path.basename(roughness_path)} \n") + + if self.albedo is not None: + albedo = self.albedo.detach().cpu().numpy() + albedo = (albedo * 255).astype(np.uint8) + cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)) + + if self.metallicRoughness is not None: + metallicRoughness = self.metallicRoughness.detach().cpu().numpy() + metallicRoughness = (metallicRoughness * 255).astype(np.uint8) + cv2.imwrite(metallic_path, metallicRoughness[..., 2]) + cv2.imwrite(roughness_path, metallicRoughness[..., 1]) + diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b339e3ea9dac5482a6daed63b069e4d1eda000a8 --- /dev/null +++ b/model/__init__.py @@ -0,0 +1 @@ +from model.crm.model import CRM \ No newline at end of file diff --git a/model/archs/__init__.py b/model/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/archs/decoders/__init__.py b/model/archs/decoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/model/archs/decoders/__init__.py @@ -0,0 +1 @@ + diff --git a/model/archs/decoders/shape_texture_net.py b/model/archs/decoders/shape_texture_net.py new file mode 100644 index 0000000000000000000000000000000000000000..706d95cb82e9295390a848f4534119fa735f52a7 --- /dev/null +++ b/model/archs/decoders/shape_texture_net.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TetTexNet(nn.Module): + def __init__(self, plane_reso=64, padding=0.1, fea_concat=True): + 
super().__init__() + # self.c_dim = c_dim + self.plane_reso = plane_reso + self.padding = padding + self.fea_concat = fea_concat + + def forward(self, rolled_out_feature, query): + # rolled_out_feature: rolled-out triplane feature + # query: queried xyz coordinates (should be scaled consistently to ptr cloud) + + plane_reso = self.plane_reso + + triplane_feature = dict() + triplane_feature['xy'] = rolled_out_feature[:, :, :, 0: plane_reso] + triplane_feature['yz'] = rolled_out_feature[:, :, :, plane_reso: 2 * plane_reso] + triplane_feature['zx'] = rolled_out_feature[:, :, :, 2 * plane_reso:] + + query_feature_xy = self.sample_plane_feature(query, triplane_feature['xy'], 'xy') + query_feature_yz = self.sample_plane_feature(query, triplane_feature['yz'], 'yz') + query_feature_zx = self.sample_plane_feature(query, triplane_feature['zx'], 'zx') + + if self.fea_concat: + query_feature = torch.cat((query_feature_xy, query_feature_yz, query_feature_zx), dim=1) + else: + query_feature = query_feature_xy + query_feature_yz + query_feature_zx + + output = query_feature.permute(0, 2, 1) + + return output + + # uses values from plane_feature and pixel locations from vgrid to interpolate feature + def sample_plane_feature(self, query, plane_feature, plane): + # CYF note: + # for pretraining, query are uniformly sampled positions w.i. [-scale, scale] + # for training, query are essentially tetrahedra grid vertices, which are + # also within [-scale, scale] in the current version! + # xy range [-scale, scale] + if plane == 'xy': + xy = query[:, :, [0, 1]] + elif plane == 'yz': + xy = query[:, :, [1, 2]] + elif plane == 'zx': + xy = query[:, :, [2, 0]] + else: + raise ValueError("Error! Invalid plane type!") + + xy = xy[:, :, None].float() + # not seem necessary to rescale the grid, because from + # https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html, + # it specifies sampling locations normalized by plane_feature's spatial dimension, + # which is within [-scale, scale] as specified by encoder's calling of coordinate2index() + vgrid = 1.0 * xy + sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1) + + return sampled_feat diff --git a/model/archs/mlp_head.py b/model/archs/mlp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7b87ff5b380bdd9b6e2f73cde42862080562eec3 --- /dev/null +++ b/model/archs/mlp_head.py @@ -0,0 +1,40 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class SdfMlp(nn.Module): + def __init__(self, input_dim, hidden_dim=512, bias=True): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias) + self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias) + self.fc3 = nn.Linear(hidden_dim, 4, bias=bias) + + + def forward(self, input): + x = F.relu(self.fc1(input)) + x = F.relu(self.fc2(x)) + out = self.fc3(x) + return out + + +class RgbMlp(nn.Module): + def __init__(self, input_dim, hidden_dim=512, bias=True): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias) + self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias) + self.fc3 = nn.Linear(hidden_dim, 3, bias=bias) + + def forward(self, input): + x = F.relu(self.fc1(input)) + x = F.relu(self.fc2(x)) + out = self.fc3(x) + + return out + + \ No newline at end of file diff --git a/model/archs/unet.py b/model/archs/unet.py new file mode 100644 
index 0000000000000000000000000000000000000000..e43b72b45e1aba8ced1801c664f4afbcd7556682 --- /dev/null +++ b/model/archs/unet.py @@ -0,0 +1,53 @@ +''' +Codes are from: +https://github.com/jaxony/unet-pytorch/blob/master/model.py +''' + +import torch +import torch.nn as nn +from diffusers import UNet2DModel +import einops +class UNetPP(nn.Module): + ''' + Wrapper for UNet in diffusers + ''' + def __init__(self, in_channels): + super(UNetPP, self).__init__() + self.in_channels = in_channels + self.unet = UNet2DModel( + sample_size=[256, 256*3], + in_channels=in_channels, + out_channels=32, + layers_per_block=2, + block_out_channels=(64, 128, 128, 128*2, 128*2, 128*4, 128*4), + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + ) + + self.unet.enable_xformers_memory_efficient_attention() + if in_channels > 12: + self.learned_plane = torch.nn.parameter.Parameter(torch.zeros([1,in_channels-12,256,256*3])) + + def forward(self, x, t=256): + learned_plane = self.learned_plane + if x.shape[1] < self.in_channels: + learned_plane = einops.repeat(learned_plane, '1 C H W -> B C H W', B=x.shape[0]).to(x.device) + x = torch.cat([x, learned_plane], dim = 1) + return self.unet(x, t).sample + diff --git a/model/crm/model.py b/model/crm/model.py new file mode 100644 index 0000000000000000000000000000000000000000..70a309283aeec7d976ed5c5f60de36c1710f51c4 --- /dev/null +++ b/model/crm/model.py @@ -0,0 +1,213 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F + +import numpy as np + + +from pathlib import Path +import cv2 +import trimesh +import nvdiffrast.torch as dr + +from model.archs.decoders.shape_texture_net import TetTexNet +from model.archs.unet import UNetPP +from util.renderer import Renderer +from model.archs.mlp_head import SdfMlp, RgbMlp +import xatlas + + +class Dummy: + pass + +class CRM(nn.Module): + def __init__(self, specs): + super(CRM, self).__init__() + + self.specs = specs + # configs + input_specs = specs["Input"] + self.input = Dummy() + self.input.scale = input_specs['scale'] + self.input.resolution = input_specs['resolution'] + self.tet_grid_size = input_specs['tet_grid_size'] + self.camera_angle_num = input_specs['camera_angle_num'] + + self.arch = Dummy() + self.arch.fea_concat = specs["ArchSpecs"]["fea_concat"] + self.arch.mlp_bias = specs["ArchSpecs"]["mlp_bias"] + + self.dec = Dummy() + self.dec.c_dim = specs["DecoderSpecs"]["c_dim"] + self.dec.plane_resolution = specs["DecoderSpecs"]["plane_resolution"] + + self.geo_type = specs["Train"].get("geo_type", "flex") # "dmtet" or "flex" + + self.unet2 = UNetPP(in_channels=self.dec.c_dim) + + mlp_chnl_s = 3 if self.arch.fea_concat else 1 # 3 for queried triplane feature concatenation + self.decoder = TetTexNet(plane_reso=self.dec.plane_resolution, fea_concat=self.arch.fea_concat) + + if self.geo_type == "flex": + self.weightMlp = nn.Sequential( + nn.Linear(mlp_chnl_s * 32 * 8, 512), + nn.SiLU(), + nn.Linear(512, 21)) + + self.sdfMlp = SdfMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias) + self.rgbMlp = RgbMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias) + self.renderer = Renderer(tet_grid_size=self.tet_grid_size, camera_angle_num=self.camera_angle_num, + scale=self.input.scale, geo_type = self.geo_type) + + + self.spob = True if specs['Pretrain']['mode'] is None else 
False # whether to add sphere + self.radius = specs['Pretrain']['radius'] # used when spob + + self.denoising = True + from diffusers import DDIMScheduler + self.scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler") + + def decode(self, data, triplane_feature2): + if self.geo_type == "flex": + tet_verts = self.renderer.flexicubes.verts.unsqueeze(0) + tet_indices = self.renderer.flexicubes.indices + + dec_verts = self.decoder(triplane_feature2, tet_verts) + out = self.sdfMlp(dec_verts) + + weight = None + if self.geo_type == "flex": + grid_feat = torch.index_select(input=dec_verts, index=self.renderer.flexicubes.indices.reshape(-1),dim=1) + grid_feat = grid_feat.reshape(dec_verts.shape[0], self.renderer.flexicubes.indices.shape[0], self.renderer.flexicubes.indices.shape[1] * dec_verts.shape[-1]) + weight = self.weightMlp(grid_feat) + weight = weight * 0.1 + + pred_sdf, deformation = out[..., 0], out[..., 1:] + if self.spob: + pred_sdf = pred_sdf + self.radius - torch.sqrt((tet_verts**2).sum(-1)) + + _, verts, faces = self.renderer(data, pred_sdf, deformation, tet_verts, tet_indices, weight= weight) + return verts[0].unsqueeze(0), faces[0].int() + + def export_mesh(self, data, out_dir, ind, device=None, tri_fea_2 = None): + verts = data['verts'] + faces = data['faces'] + + dec_verts = self.decoder(tri_fea_2, verts.unsqueeze(0)) + colors = self.rgbMlp(dec_verts).squeeze().detach().cpu().numpy() + # Expect predicted colors value range from [-1, 1] + colors = (colors * 0.5 + 0.5).clip(0, 1) + + verts = verts.squeeze().cpu().numpy() + faces = faces[..., [2, 1, 0]].squeeze().cpu().numpy() + + # export the final mesh + with torch.no_grad(): + mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False) # important, process=True leads to seg fault... + mesh.export(out_dir / f'{ind}.obj') + + def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None): + + mesh_v = data['verts'].squeeze().cpu().numpy() + mesh_pos_idx = data['faces'].squeeze().cpu().numpy() + + def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, + diff_attrs=None if rast_db is None else 'all') + + vmapping, indices, uvs = xatlas.parametrize(mesh_v, mesh_pos_idx) + + mesh_v = torch.tensor(mesh_v, dtype=torch.float32, device=device) + mesh_pos_idx = torch.tensor(mesh_pos_idx, dtype=torch.int64, device=device) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) + + uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) + mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) + # mesh_v_tex. ture + uv_clip = uvs[None, ...] 
* 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), res) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) + mask = rast[..., 3:4] > 0 + + # return uvs, mesh_tex_idx, gb_pos, mask + gb_pos_unsqz = gb_pos.view(-1, 3) + mask_unsqz = mask.view(-1) + tex_unsqz = torch.zeros_like(gb_pos_unsqz) + 1 + + gb_mask_pos = gb_pos_unsqz[mask_unsqz] + + gb_mask_pos = gb_mask_pos[None, ] + + with torch.no_grad(): + + dec_verts = self.decoder(tri_fea_2, gb_mask_pos) + colors = self.rgbMlp(dec_verts).squeeze() + + # Expect predicted colors value range from [-1, 1] + lo, hi = (-1, 1) + colors = (colors - lo) * (255 / (hi - lo)) + colors = colors.clip(0, 255) + + tex_unsqz[mask_unsqz] = colors + + tex = tex_unsqz.view(res + (3,)) + + verts = mesh_v.squeeze().cpu().numpy() + faces = mesh_pos_idx[..., [2, 1, 0]].squeeze().cpu().numpy() + # faces = mesh_pos_idx + # faces = faces.detach().cpu().numpy() + # faces = faces[..., [2, 1, 0]] + indices = indices[..., [2, 1, 0]] + + # xatlas.export(f"{out_dir}/{ind}.obj", verts[vmapping], indices, uvs) + matname = f'{out_dir}.mtl' + # matname = f'{out_dir}/{ind}.mtl' + fid = open(matname, 'w') + fid.write('newmtl material_0\n') + fid.write('Kd 1 1 1\n') + fid.write('Ka 1 1 1\n') + # fid.write('Ks 0 0 0\n') + fid.write('Ks 0.4 0.4 0.4\n') + fid.write('Ns 10\n') + fid.write('illum 2\n') + fid.write(f'map_Kd {out_dir.split("/")[-1]}.png\n') + fid.close() + + fid = open(f'{out_dir}.obj', 'w') + # fid = open(f'{out_dir}/{ind}.obj', 'w') + fid.write('mtllib %s.mtl\n' % out_dir.split("/")[-1]) + + for pidx, p in enumerate(verts): + pp = p + fid.write('v %f %f %f\n' % (pp[0], pp[2], - pp[1])) + + for pidx, p in enumerate(uvs): + pp = p + fid.write('vt %f %f\n' % (pp[0], 1 - pp[1])) + + fid.write('usemtl material_0\n') + for i, f in enumerate(faces): + f1 = f + 1 + f2 = indices[i] + 1 + fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) + fid.close() + + img = np.asarray(tex.data.cpu().numpy(), dtype=np.float32) + mask = np.sum(img.astype(float), axis=-1, keepdims=True) + mask = (mask <= 3.0).astype(float) + kernel = np.ones((3, 3), 'uint8') + dilate_img = cv2.dilate(img, kernel, iterations=1) + img = img * (1 - mask) + dilate_img * mask + img = img.clip(0, 255).astype(np.uint8) + + cv2.imwrite(f'{out_dir}.png', img[..., [2, 1, 0]]) + # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]]) diff --git a/pipelines.py b/pipelines.py new file mode 100644 index 0000000000000000000000000000000000000000..370eff5954f1850b18f626c487cbc55ad8ed98c2 --- /dev/null +++ b/pipelines.py @@ -0,0 +1,170 @@ +import torch +from libs.base_utils import do_resize_content +from imagedream.ldm.util import ( + instantiate_from_config, + get_obj_from_str, +) +from omegaconf import OmegaConf +from PIL import Image +import numpy as np + + +class TwoStagePipeline(object): + def __init__( + self, + stage1_model_config, + stage2_model_config, + stage1_sampler_config, + stage2_sampler_config, + device="cuda", + dtype=torch.float16, + resize_rate=1, + ) -> None: + """ + only for two stage generate process. 
+ - the first stage was condition on single pixel image, gererate multi-view pixel image, based on the v2pp config + - the second stage was condition on multiview pixel image generated by the first stage, generate the final image, based on the stage2-test config + """ + self.resize_rate = resize_rate + + self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model) + self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False) + self.stage1_model = self.stage1_model.to(device).to(dtype) + + self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model) + sd = torch.load(stage2_model_config.resume, map_location="cpu") + self.stage2_model.load_state_dict(sd, strict=False) + self.stage2_model = self.stage2_model.to(device).to(dtype) + + self.stage1_model.device = device + self.stage2_model.device = device + self.device = device + self.dtype = dtype + self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)( + self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params + ) + self.stage2_sampler = get_obj_from_str(stage2_sampler_config.target)( + self.stage2_model, device=device, dtype=dtype, **stage2_sampler_config.params + ) + + def stage1_sample( + self, + pixel_img, + prompt="3D assets", + neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.", + step=50, + scale=5, + ddim_eta=0.0, + ): + if type(pixel_img) == str: + pixel_img = Image.open(pixel_img) + + if isinstance(pixel_img, Image.Image): + if pixel_img.mode == "RGBA": + background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0)) + pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB") + else: + pixel_img = pixel_img.convert("RGB") + else: + raise + uc = self.stage1_sampler.model.get_learned_conditioning([neg_texts]).to(self.device) + stage1_images = self.stage1_sampler.i2i( + self.stage1_sampler.model, + self.stage1_sampler.size, + prompt, + uc=uc, + sampler=self.stage1_sampler.sampler, + ip=pixel_img, + step=step, + scale=scale, + batch_size=self.stage1_sampler.batch_size, + ddim_eta=ddim_eta, + dtype=self.stage1_sampler.dtype, + device=self.stage1_sampler.device, + camera=self.stage1_sampler.camera, + num_frames=self.stage1_sampler.num_frames, + pixel_control=(self.stage1_sampler.mode == "pixel"), + transform=self.stage1_sampler.image_transform, + offset_noise=self.stage1_sampler.offset_noise, + ) + + stage1_images = [Image.fromarray(img) for img in stage1_images] + stage1_images.pop(self.stage1_sampler.ref_position) + return stage1_images + + def stage2_sample(self, pixel_img, stage1_images): + if type(pixel_img) == str: + pixel_img = Image.open(pixel_img) + + if isinstance(pixel_img, Image.Image): + if pixel_img.mode == "RGBA": + background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0)) + pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB") + else: + pixel_img = pixel_img.convert("RGB") + else: + raise + stage2_images = self.stage2_sampler.i2iStage2( + self.stage2_sampler.model, + self.stage2_sampler.size, + "3D assets", + self.stage2_sampler.uc, + self.stage2_sampler.sampler, + pixel_images=stage1_images, + ip=pixel_img, + step=50, + scale=5, + batch_size=self.stage2_sampler.batch_size, + ddim_eta=0.0, + dtype=self.stage2_sampler.dtype, + device=self.stage2_sampler.device, + camera=self.stage2_sampler.camera, + num_frames=self.stage2_sampler.num_frames, + 
pixel_control=(self.stage2_sampler.mode == "pixel"), + transform=self.stage2_sampler.image_transform, + offset_noise=self.stage2_sampler.offset_noise, + ) + stage2_images = [Image.fromarray(img) for img in stage2_images] + return stage2_images + + def set_seed(self, seed): + self.stage1_sampler.seed = seed + self.stage2_sampler.seed = seed + + def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50): + pixel_img = do_resize_content(pixel_img, self.resize_rate) + stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step) + stage2_images = self.stage2_sample(pixel_img, stage1_images) + + return { + "ref_img": pixel_img, + "stage1_images": stage1_images, + "stage2_images": stage2_images, + } + + +if __name__ == "__main__": + + stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config + stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config + stage2_sampler_config = stage2_config.sampler + stage1_sampler_config = stage1_config.sampler + + stage1_model_config = stage1_config.models + stage2_model_config = stage2_config.models + + pipeline = TwoStagePipeline( + stage1_model_config, + stage2_model_config, + stage1_sampler_config, + stage2_sampler_config, + ) + + img = Image.open("assets/astronaut.png") + rt_dict = pipeline(img) + stage1_images = rt_dict["stage1_images"] + stage2_images = rt_dict["stage2_images"] + np_imgs = np.concatenate(stage1_images, 1) + np_xyzs = np.concatenate(stage2_images, 1) + Image.fromarray(np_imgs).save("pixel_images.png") + Image.fromarray(np_xyzs).save("xyz_images.png") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ae91e22851faea6eeb50236a90403642a84d7fd1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +gradio +huggingface-hub +einops==0.7.0 +Pillow==10.1.0 +transformers==4.27.1 +open-clip-torch==2.7.0 +opencv-contrib-python-headless==4.9.0.80 +opencv-python-headless==4.9.0.80 +xformers +omegaconf +rembg +nvdiffrast +pygltflib +kiui \ No newline at end of file diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/util/flexicubes.py b/util/flexicubes.py new file mode 100644 index 0000000000000000000000000000000000000000..1a22d63cafc1f611e09b9a684041abb29ccbcaec --- /dev/null +++ b/util/flexicubes.py @@ -0,0 +1,579 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +from util.tables import * + +__all__ = [ + 'FlexiCubes' +] + + +class FlexiCubes: + """ + This class implements the FlexiCubes method for extracting meshes from scalar fields. + It maintains a series of lookup tables and indices to support the mesh extraction process. + FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances + the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting + the surface representation through gradient-based optimization. 
+ + During instantiation, the class loads DMC tables from a file and transforms them into + PyTorch tensors on the specified device. + + Attributes: + device (str): Specifies the computational device (default is "cuda"). + dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges + associated with each dual vertex in 256 Marching Cubes (MC) configurations. + num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of + the 256 MC configurations. + check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19 + of the DMC configurations. + tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface. + quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles + along one diagonal. + quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into + two triangles along the other diagonal. + quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles + during training by connecting all edges to their midpoints. + cube_corners (torch.Tensor): Defines the positions of a standard unit cube's + eight corners in 3D space, ordered starting from the origin (0,0,0), + moving along the x-axis, then y-axis, and finally z-axis. + Used as a blueprint for generating a voxel grid. + cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used + to retrieve the case id. + cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs. + Used to retrieve edge vertices in DMC. + edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with + their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the + first edge is oriented along the x-axis. + dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges + across four adjacent cubes to the shared faces of these cubes. For instance, + dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along + the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively. + This tensor is only utilized during isosurface tetrahedralization. + adj_pairs (torch.Tensor): + A tensor containing index pairs that correspond to neighboring cubes that share the same edge. + qef_reg_scale (float): + The scaling factor applied to the regularization loss to prevent issues with singularity + when solving the QEF. This parameter is only used when a 'grad_func' is specified. + weight_scale (float): + The scale of weights in FlexiCubes. Should be between 0 and 1. 
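+
+    Example:
+        A minimal usage sketch (the sphere SDF below is only an illustration; any
+        scalar field evaluated at the grid vertices can be passed in the same way):
+
+            fc = FlexiCubes(device="cuda")
+            x_nx3, cube_fx8 = fc.construct_voxel_grid(64)        # unit-length grid centered at the origin
+            s_n = x_nx3.norm(dim=-1) - 0.4                        # signed distance to a sphere
+            vertices, faces, L_dev = fc(x_nx3, s_n, cube_fx8, 64) # extracted triangle mesh + regularizer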
+ """ + + def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99): + + self.device = device + self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) + self.num_vd_table = torch.tensor(num_vd_table, + dtype=torch.long, device=device, requires_grad=False) + self.check_table = torch.tensor( + check_table, + dtype=torch.long, device=device, requires_grad=False) + + self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) + self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_train = torch.tensor( + [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) + + self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) + self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) + self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) + + self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], + dtype=torch.long, device=device) + self.dir_faces_table = torch.tensor([ + [[5, 4], [3, 2], [4, 5], [2, 3]], + [[5, 4], [1, 0], [4, 5], [0, 1]], + [[3, 2], [1, 0], [2, 3], [0, 1]] + ], dtype=torch.long, device=device) + self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) + self.qef_reg_scale = qef_reg_scale + self.weight_scale = weight_scale + + def construct_voxel_grid(self, res): + """ + Generates a voxel grid based on the specified resolution. + + Args: + res (int or list[int]): The resolution of the voxel grid. If an integer + is provided, it is used for all three dimensions. If a list or tuple + of 3 integers is provided, they define the resolution for the x, + y, and z dimensions respectively. + + Returns: + (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the + cube corners (index into vertices) of the constructed voxel grid. + The vertices are centered at the origin, with the length of each + dimension in the grid being one. + """ + base_cube_f = torch.arange(8).to(self.device) + if isinstance(res, int): + res = (res, res, res) + voxel_grid_template = torch.ones(res, device=self.device) + + res = torch.tensor([res], dtype=torch.float, device=self.device) + coords = torch.nonzero(voxel_grid_template).float() / res # N, 3 + verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3) + cubes = (base_cube_f.unsqueeze(0) + + torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1) + + verts_rounded = torch.round(verts * 10**5) / (10**5) + verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True) + cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8) + + return verts_unique - 0.5, cubes + + def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None, + gamma_f=None, training=False, output_tetmesh=False, grad_func=None): + r""" + Main function for mesh extraction from scalar field using FlexiCubes. 
This function converts + discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, + to triangle or tetrahedral meshes using a differentiable operation as described in + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances + mesh quality and geometric fidelity by adjusting the surface representation based on gradient + optimization. The output surface is differentiable with respect to the input vertex positions, + scalar field values, and weight parameters. + + If you intend to extract a surface mesh from a fixed Signed Distance Field without the + optimization of parameters, it is suggested to provide the "grad_func" which should + return the surface gradient at any given 3D position. When grad_func is provided, the process + to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as + described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. + Please note, this approach is non-differentiable. + + For more details and example usage in optimization, refer to the + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper. + + Args: + x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed. + s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values + denote that the corresponding vertex resides inside the isosurface. This affects + the directions of the extracted triangle faces and volume to be tetrahedralized. + cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid. + res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it + is used for all three dimensions. If a list or tuple of 3 integers is provided, they + specify the resolution for the x, y, and z dimensions respectively. + beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual + vertices positioning. Defaults to uniform value for all edges. + alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual + vertices positioning. Defaults to uniform value for all vertices. + gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of + quadrilaterals into triangles. Defaults to uniform value for all cubes. + training (bool, optional): If set to True, applies differentiable quad splitting for + training. Defaults to False. + output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, + outputs a triangular mesh. Defaults to False. + grad_func (callable, optional): A function to compute the surface gradient at specified + 3D positions (input: Nx3 positions). The function should return gradients as an Nx3 + tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None. + + Returns: + (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing: + - Vertices for the extracted triangular/tetrahedral mesh. + - Faces for the extracted triangular/tetrahedral mesh. + - Regularizer L_dev, computed per dual vertex. + + .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization: + https://research.nvidia.com/labs/toronto-ai/flexicubes/ + .. 
_Manifold Dual Contouring: + https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf + """ + + surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8) + if surf_cubes.sum() == 0: + return torch.zeros( + (0, 3), + device=self.device), torch.zeros( + (0, 4), + dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros( + (0, 3), + dtype=torch.long, device=self.device), torch.zeros( + (0), + device=self.device) + beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes) + + case_ids = self._get_case_id(occ_fx8, surf_cubes, res) + + surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes) + + vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd( + x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func) + vertices, faces, s_edges, edge_indices = self._triangulate( + s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func) + if not output_tetmesh: + return vertices, faces, L_dev + else: + vertices, tets = self._tetrahedralize( + x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, + surf_cubes, training) + return vertices, tets, L_dev + + def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): + """ + Regularizer L_dev as in Equation 8 + """ + dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + mean_l2 = torch.zeros_like(vd[:, 0]) + mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() + mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + return mad + + def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes): + """ + Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. + """ + n_cubes = surf_cubes.shape[0] + + if beta_fx12 is not None: + beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1) + else: + beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) + + if alpha_fx8 is not None: + alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1) + else: + alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) + + if gamma_f is not None: + gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2 + else: + gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) + + return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes] + + @torch.no_grad() + def _get_case_id(self, occ_fx8, surf_cubes, res): + """ + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + supplementary material. It should be noted that this function assumes a regular grid. + """ + case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) + + problem_config = self.check_table.to(self.device)[case_ids] + to_check = problem_config[..., 0] == 1 + problem_config = problem_config[to_check] + if not isinstance(res, (list, tuple)): + res = [res, res, res] + + # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, + # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). 
+ # This allows efficient checking on adjacent cubes. + problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) + vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 + vol_idx_problem = vol_idx[surf_cubes][to_check] + problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] + + within_range = ( + vol_idx_problem_adj[..., 0] >= 0) & ( + vol_idx_problem_adj[..., 0] < res[0]) & ( + vol_idx_problem_adj[..., 1] >= 0) & ( + vol_idx_problem_adj[..., 1] < res[1]) & ( + vol_idx_problem_adj[..., 2] >= 0) & ( + vol_idx_problem_adj[..., 2] < res[2]) + + vol_idx_problem = vol_idx_problem[within_range] + vol_idx_problem_adj = vol_idx_problem_adj[within_range] + problem_config = problem_config[within_range] + problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], + vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] + # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. + to_invert = (problem_config_adj[..., 0] == 1) + idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] + case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) + return case_ids + + @torch.no_grad() + def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes): + """ + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + and marks the cube edges with this index. + """ + occ_n = s_n < 0 + all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + + surf_edges_mask = mask_edges[_idx_map] + counts = counts[_idx_map] + + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device) + # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index + # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. + idx_map = mapping[_idx_map] + surf_edges = unique_edges[mask_edges] + return surf_edges, idx_map, counts, surf_edges_mask + + @torch.no_grad() + def _identify_surf_cubes(self, s_n, cube_fx8): + """ + Identifies grid cubes that intersect with the underlying surface by checking if the signs at + all corners are not identical. + """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + _occ_sum = torch.sum(occ_fx8, -1) + surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) + return surf_cubes, occ_fx8 + + def _linear_interp(self, edges_weight, edges_x): + """ + Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. 
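+        Concretely, for an edge with endpoint weights (w0, w1) and endpoint positions
+        (x0, x1), the returned location is ue = (w1 * x0 - w0 * x1) / (w1 - w0), i.e.
+        the point where the linear interpolant of the weights changes sign.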
+ """ + edge_dim = edges_weight.dim() - 2 + assert edges_weight.shape[edge_dim] == 2 + edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - + torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim) + denominator = edges_weight.sum(edge_dim) + ue = (edges_x * edges_weight).sum(edge_dim) / denominator + return ue + + def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None): + p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) + norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) + c_bx3 = c_bx3.reshape(-1, 3) + A = norm_bxnx3 + B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) + + A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) + B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1) + A = torch.cat([A, A_reg], 1) + B = torch.cat([B, B_reg], 1) + dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) + return dual_verts + + def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func): + """ + Computes the location of dual vertices as described in Section 4.2 + """ + alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2) + surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) + surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) + + idx_map = idx_map.reshape(-1, 12) + num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] + + total_num_vd = 0 + vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) + if grad_func is not None: + normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1) + vd = [] + for num in torch.unique(num_vd): + cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) + curr_num_vd = cur_cubes.sum() * num + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) + curr_edge_group_to_vd = torch.arange( + curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd + total_num_vd += curr_num_vd + curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ + cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) + + curr_mask = (curr_edge_group != -1) + edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) + edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) + edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) + vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) + vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) + + if grad_func is not None: + with torch.no_grad(): + cube_e_verts_idx = idx_map[cur_cubes] + curr_edge_group[~curr_mask] = 0 + + verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group) + verts_group_idx[verts_group_idx == -1] = 0 + verts_group_pos = torch.index_select( + input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3) + v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 
1) + curr_mask = curr_mask.reshape(-1, num.item(), 7, 1) + verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2)) + + normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape( + -1, num.item(), 7, + 3) + curr_mask = curr_mask.squeeze(2) + vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask, + verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3)) + edge_group = torch.cat(edge_group) + edge_group_to_vd = torch.cat(edge_group_to_vd) + edge_group_to_cube = torch.cat(edge_group_to_cube) + vd_num_edges = torch.cat(vd_num_edges) + vd_gamma = torch.cat(vd_gamma) + + if grad_func is not None: + vd = torch.cat(vd) + L_dev = torch.zeros([1], device=self.device) + else: + vd = torch.zeros((total_num_vd, 3), device=self.device) + beta_sum = torch.zeros((total_num_vd, 1), device=self.device) + + idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + + x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) + s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) + + zero_crossing_group = torch.index_select( + input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) + + alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) + ue_group = self._linear_interp(s_group * alpha_group, x_group) + + beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) + beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) + vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum + L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) + + v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd + + vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * + 12 + edge_group, src=v_idx[edge_group_to_vd]) + + return vd, L_dev, vd_gamma, vd_idx_map + + def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func): + """ + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + triangles based on the gamma parameter, as described in Section 4.3. + """ + with torch.no_grad(): + group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group = idx_map.reshape(-1)[group_mask] + vd_idx = vd_idx_map[group_mask] + edge_indices, indices = torch.sort(group, stable=True) + quad_vd_idx = vd_idx[indices].reshape(-1, 4) + + # Ensure all face directions point towards the positive SDF to maintain consistent winding. + s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + flip_mask = s_edges[:, 0] > 0 + quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], + quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + if grad_func is not None: + # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients. 
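+            # (vd_gamma is overwritten below with the surface normals at the dual vertices; each quad is then
+            # split along the diagonal whose endpoint normals are more aligned, per the two dot products.)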
+ with torch.no_grad(): + vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1) + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True) + gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True) + else: + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor( + 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1) + gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor( + 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1) + if not training: + mask = (gamma_02 > gamma_13).squeeze(1) + faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) + faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] + faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] + faces = faces.reshape(-1, 3) + else: + vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) + + torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2 + vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) + + torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2 + weight_sum = (gamma_02 + gamma_13) + 1e-8 + vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / + weight_sum.unsqueeze(-1)).squeeze(1) + vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + vd = torch.cat([vd, vd_center]) + faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) + faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) + return vd, faces, s_edges, edge_indices + + def _tetrahedralize( + self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, + surf_cubes, training): + """ + Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5. 
+ """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + occ_sum = torch.sum(occ_fx8, -1) + + inside_verts = x_nx3[occ_n] + mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0] + """ + For each grid edge connecting two grid vertices with different + signs, we first form a four-sided pyramid by connecting one + of the grid vertices with four mesh vertices that correspond + to the grid edge and then subdivide the pyramid into two tetrahedra + """ + inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[ + s_edges < 0]] + if not training: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1) + else: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1) + + tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1) + """ + For each grid edge connecting two grid vertices with the + same sign, the tetrahedron is formed by the two grid vertices + and two vertices in consecutive adjacent cells + """ + inside_cubes = (occ_sum == 8) + inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1) + inside_cubes_center_idx = torch.arange( + inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0] + + surface_n_inside_cubes = surf_cubes | inside_cubes + edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13), + dtype=torch.long, device=x_nx3.device) * -1 + surf_cubes = surf_cubes[surface_n_inside_cubes] + inside_cubes = inside_cubes[surface_n_inside_cubes] + edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12) + edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx + + all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2 + mask = mask_edges[_idx_map] + counts = counts[_idx_map] + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device) + idx_map = mapping[_idx_map] + + group_mask = (counts == 4) & mask + group = idx_map.reshape(-1)[group_mask] + edge_indices, indices = torch.sort(group) + cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long, + device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask] + edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze( + 0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask] + # Identify the face shared by the adjacent cells. + cube_idx_4 = cube_idx[indices].reshape(-1, 4) + edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0] + shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1) + cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1) + # Identify an edge of the face with different signs and + # select the mesh vertex corresponding to the identified edge. 
+ case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255 + case_ids_expand[surf_cubes] = case_ids + cases = case_ids_expand[cube_idx_4x2] + quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2) + mask = (quad_edge == -1).sum(-1) == 0 + inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2) + tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask] + + tets = torch.cat([tets_surface, tets_inside]) + vertices = torch.cat([vertices, inside_verts, inside_cubes_center]) + return vertices, tets diff --git a/util/flexicubes_geometry.py b/util/flexicubes_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..055805f3ad1396cef768e02e2211b4b3b5601c82 --- /dev/null +++ b/util/flexicubes_geometry.py @@ -0,0 +1,116 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from util.flexicubes import FlexiCubes # replace later +# from dmtet import sdf_reg_loss_batch +import torch.nn.functional as F + +def get_center_boundary_index(grid_res, device): + v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device) + v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True + center_indices = torch.nonzero(v.reshape(-1)) + + v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False + v[:2, ...] = True + v[-2:, ...] = True + v[:, :2, ...] = True + v[:, -2:, ...] 
= True + v[:, :, :2] = True + v[:, :, -2:] = True + boundary_indices = torch.nonzero(v.reshape(-1)) + return center_indices, boundary_indices + +############################################################################### +# Geometry interface +############################################################################### +class FlexiCubesGeometry(object): + def __init__( + self, grid_res=64, scale=2.0, device='cuda', renderer=None, + render_type='neural_render', args=None): + super(FlexiCubesGeometry, self).__init__() + self.grid_res = grid_res + self.device = device + self.args = args + self.fc = FlexiCubes(device, weight_scale=0.5) + self.verts, self.indices = self.fc.construct_voxel_grid(grid_res) + if isinstance(scale, list): + self.verts[:, 0] = self.verts[:, 0] * scale[0] + self.verts[:, 1] = self.verts[:, 1] * scale[1] + self.verts[:, 2] = self.verts[:, 2] * scale[1] + else: + self.verts = self.verts * scale + + all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) + self.all_edges = torch.unique(all_edges, dim=0) + + # Parameters used for fix boundary sdf + self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device) + self.renderer = renderer + self.render_type = render_type + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False): + if indices is None: + indices = self.indices + + verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res, + beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20], + gamma_f=weight_n[:, 20], training=is_training + ) + return verts, faces, v_reg_loss + + + def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): + return_value = dict() + if self.render_type == 'neural_render': + tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh( + mesh_v_nx3.unsqueeze(dim=0), + mesh_f_fx3.int(), + camera_mv_bx4x4, + mesh_v_nx3.unsqueeze(dim=0), + resolution=resolution, + device=self.device, + hierarchical_mask=hierarchical_mask + ) + + return_value['tex_pos'] = tex_pos + return_value['mask'] = mask + return_value['hard_mask'] = hard_mask + return_value['rast'] = rast + return_value['v_pos_clip'] = v_pos_clip + return_value['mask_pyramid'] = mask_pyramid + return_value['depth'] = depth + else: + raise NotImplementedError + + return return_value + + def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): + # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 + v_list = [] + f_list = [] + n_batch = v_deformed_bxnx3.shape[0] + all_render_output = [] + for i_batch in range(n_batch): + verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) + v_list.append(verts_nx3) + f_list.append(faces_fx3) + render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) + all_render_output.append(render_output) + + # Concatenate all render output + return_keys = all_render_output[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in all_render_output] + return_value[k] = value + # We can do concatenation outside of the render + return return_value diff --git a/util/renderer.py b/util/renderer.py new file mode 100644 index 
0000000000000000000000000000000000000000..a37627cb7782777c21eb26bf8aa9909c0da9ca25 --- /dev/null +++ b/util/renderer.py @@ -0,0 +1,49 @@ + +import torch +import torch.nn as nn +import nvdiffrast.torch as dr +from util.flexicubes_geometry import FlexiCubesGeometry + +class Renderer(nn.Module): + def __init__(self, tet_grid_size, camera_angle_num, scale, geo_type): + super().__init__() + + self.tet_grid_size = tet_grid_size + self.camera_angle_num = camera_angle_num + self.scale = scale + self.geo_type = geo_type + self.glctx = dr.RasterizeCudaContext() + + if self.geo_type == "flex": + self.flexicubes = FlexiCubesGeometry(grid_res = self.tet_grid_size) + + def forward(self, data, sdf, deform, verts, tets, training=False, weight = None): + + results = {} + + deform = torch.tanh(deform) / self.tet_grid_size * self.scale / 0.95 + if self.geo_type == "flex": + deform = deform *0.5 + + v_deformed = verts + deform + + verts_list = [] + faces_list = [] + reg_list = [] + n_shape = verts.shape[0] + for i in range(n_shape): + verts_i, faces_i, reg_i = self.flexicubes.get_mesh(v_deformed[i], sdf[i].squeeze(dim=-1), + with_uv=False, indices=tets, weight_n=weight[i], is_training=training) + + verts_list.append(verts_i) + faces_list.append(faces_i) + reg_list.append(reg_i) + verts = verts_list + faces = faces_list + + flexicubes_surface_reg = torch.cat(reg_list).mean() + flexicubes_weight_reg = (weight ** 2).mean() + results["flex_surf_loss"] = flexicubes_surface_reg + results["flex_weight_loss"] = flexicubes_weight_reg + + return results, verts, faces \ No newline at end of file diff --git a/util/tables.py b/util/tables.py new file mode 100644 index 0000000000000000000000000000000000000000..5873e7727b5595a1e4fbc3bd10ae5be8f3d06cca --- /dev/null +++ b/util/tables.py @@ -0,0 +1,791 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
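+
+# Dual Marching Cubes lookup table: one entry for each of the 256 corner-sign configurations.
+# Each entry lists up to 4 dual vertices, and each dual vertex is described by the indices of
+# the cube edges it is associated with (up to 7), padded with -1.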
+dmc_table = [ +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, 
-1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, 
-1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], 
+[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 
3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], 
[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, 
-1, -1, -1, -1, -1]], +[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, 
-1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, 
-1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] +] +num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, +2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, +1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, +1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, +2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, +3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, +2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, +1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, +1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, +1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] +check_table = [ +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 
0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 194], +[1, -1, 0, 0, 193], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 164], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 161], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 152], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 145], +[1, 0, 0, 1, 144], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 137], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 133], +[1, 0, 1, 0, 132], +[1, 1, 0, 0, 131], +[1, 1, 0, 0, 130], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 100], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 98], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 96], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 88], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 82], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 74], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 72], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 70], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 67], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 65], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 56], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 52], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 44], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 40], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 38], +[1, 0, -1, 0, 37], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 33], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 28], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 26], +[1, 0, 0, -1, 25], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 20], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 18], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 
-1, 9], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 6], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0] +] +tet_table = [ +[-1, -1, -1, -1, -1, -1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, -1], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, -1], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, -1, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, -1, 2, 4, 4, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, 5, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, -1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[-1, 1, 1, 4, 4, 1], +[0, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[8, 8, 8, 8, 8, 8], +[1, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 4, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 5, 5, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[6, 6, 6, 6, 6, 6], +[6, -1, 0, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 4, -1, 6, 4, 6], +[6, 4, 0, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 2, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 1, 1, 6, -1, 6], +[6, 1, 1, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 4], +[2, 2, 2, 2, 2, 2], +[6, 1, 1, 6, 4, 6], +[6, 1, 1, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 5, 0, 5, 0, 5], +[5, 5, 5, 5, 5, 5], +[5, 5, 5, 5, 5, 5], +[0, 5, 0, 5, 0, 5], +[-1, 5, 0, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[4, 5, -1, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[4, 5, 0, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 6, 6, 6, 6, 6], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, -1, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[2, 5, 2, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 4], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 6, 2, 6, 6, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 
1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 1, 4, 1], +[0, 1, 1, 1, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 0, 0, 6, 0, 6], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[5, 5, 5, 5, 5, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 4, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[4, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[8, 8, 8, 8, 8, 8], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 1, 1, 4, 4, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 4, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[12, 12, 12, 12, 12, 12] +] diff --git a/util/utils.py b/util/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e4cb35ef3712f11987c5a4432aee95395d93e6b7 --- /dev/null +++ b/util/utils.py @@ -0,0 +1,194 @@ +import numpy as np +import torch +import random + + +# Reworked so this matches gluPerspective / glm::perspective, using fovy +def perspective(fovx=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): + # y = np.tan(fovy / 2) + x = np.tan(fovx / 2) + return torch.tensor([[1/x, 0, 0, 0], + [ 0, -aspect/x, 0, 0], + [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], + [ 0, 0, -1, 0]], dtype=torch.float32, device=device) + + +def translate(x, y, z, device=None): + return torch.tensor([[1, 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + + +def rotate_x(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[1, 0, 0, 0], + [0, c, -s, 0], + [0, s, c, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + + +def rotate_y(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[ c, 0, s, 0], + [ 0, 1, 0, 0], + [-s, 0, c, 0], + [ 0, 0, 0, 1]], dtype=torch.float32, device=device) + + +def rotate_z(a, device=None): + s, c = np.sin(a), np.cos(a) + return torch.tensor([[c, -s, 0, 0], + [s, c, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]], dtype=torch.float32, device=device) + +@torch.no_grad() +def batch_random_rotation_translation(b, t, device=None): + m = np.random.normal(size=[b, 3, 3]) + m[:, 1] = np.cross(m[:, 0], m[:, 2]) + m[:, 2] = 
np.cross(m[:, 0], m[:, 1]) + m = m / np.linalg.norm(m, axis=2, keepdims=True) + m = np.pad(m, [[0, 0], [0, 1], [0, 1]], mode='constant') + m[:, 3, 3] = 1.0 + m[:, :3, 3] = np.random.uniform(-t, t, size=[b, 3]) + return torch.tensor(m, dtype=torch.float32, device=device) + +@torch.no_grad() +def random_rotation_translation(t, device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.random.uniform(-t, t, size=[3]) + return torch.tensor(m, dtype=torch.float32, device=device) + + +@torch.no_grad() +def random_rotation(device=None): + m = np.random.normal(size=[3, 3]) + m[1] = np.cross(m[0], m[2]) + m[2] = np.cross(m[0], m[1]) + m = m / np.linalg.norm(m, axis=1, keepdims=True) + m = np.pad(m, [[0, 1], [0, 1]], mode='constant') + m[3, 3] = 1.0 + m[:3, 3] = np.array([0,0,0]).astype(np.float32) + return torch.tensor(m, dtype=torch.float32, device=device) + + +def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.sum(x*y, -1, keepdim=True) + + +def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN + + +def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: + return x / length(x, eps) + + +def lr_schedule(iter, warmup_iter, scheduler_decay): + if iter < warmup_iter: + return iter / warmup_iter + return max(0.0, 10 ** ( + -(iter - warmup_iter) * scheduler_decay)) + + +def trans_depth(depth): + depth = depth[0].detach().cpu().numpy() + valid = depth > 0 + depth[valid] -= depth[valid].min() + depth[valid] = ((depth[valid] / depth[valid].max()) * 255) + return depth.astype('uint8') + + +def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + + +def load_item(filepath): + with open(filepath, 'r') as f: + items = [name.strip() for name in f.readlines()] + return set(items) + +def load_prompt(filepath): + uuid2prompt = {} + with open(filepath, 'r') as f: + for line in f.readlines(): + list_line = line.split(',') + uuid2prompt[list_line[0]] = ','.join(list_line[1:]).strip() + return uuid2prompt + +def resize_and_center_image(image_tensor, scale=0.95, c = 0, shift = 0, rgb=False, aug_shift = 0): + if scale == 1: + return image_tensor + B, C, H, W = image_tensor.shape + new_H, new_W = int(H * scale), int(W * scale) + resized_image = torch.nn.functional.interpolate(image_tensor, size=(new_H, new_W), mode='bilinear', align_corners=False).squeeze(0) + background = torch.zeros_like(image_tensor) + c + start_y, start_x = (H - new_H) // 2, (W - new_W) // 2 + if shift == 0: + background[:, :, start_y:start_y + new_H, start_x:start_x + new_W] = resized_image + else: + for i in range(B): + randx = random.randint(-shift, shift) + randy = random.randint(-shift, shift) + if rgb == True: + if i == 0 or i==2 or i==4: + randx = 0 + randy = 0 + background[i, :, start_y+randy:start_y + new_H+randy, start_x+randx:start_x + new_W+randx] = resized_image[i] + if aug_shift == 0: + return background + for i in range(B): + for j in range(C): + background[i, j, :, :] += (random.random() - 0.5)*2 * aug_shift / 255 + 
return background + +def get_tri(triview_color, dim = 1, blender=True, c = 0, scale=0.95, shift = 0, fix = False, rgb=False, aug_shift = 0): + # triview_color: [6,C,H,W] + # rgb is useful when shift is not 0 + triview_color = resize_and_center_image(triview_color, scale=scale, c = c, shift=shift,rgb=rgb, aug_shift = aug_shift) + if blender is False: + triview_color0 = torch.rot90(triview_color[0],k=2,dims=[1,2]) + triview_color1 = torch.rot90(triview_color[4],k=1,dims=[1,2]).flip(2).flip(1) + triview_color2 = torch.rot90(triview_color[5],k=1,dims=[1,2]).flip(2) + triview_color3 = torch.rot90(triview_color[3],k=2,dims=[1,2]).flip(2) + triview_color4 = torch.rot90(triview_color[1],k=3,dims=[1,2]).flip(1) + triview_color5 = torch.rot90(triview_color[2],k=3,dims=[1,2]).flip(1).flip(2) + else: + triview_color0 = torch.rot90(triview_color[2],k=2,dims=[1,2]) + triview_color1 = torch.rot90(triview_color[4],k=0,dims=[1,2]).flip(2).flip(1) + triview_color2 = torch.rot90(torch.rot90(triview_color[0],k=3,dims=[1,2]).flip(2), k=2,dims=[1,2]) + triview_color3 = torch.rot90(torch.rot90(triview_color[5],k=2,dims=[1,2]).flip(2), k=2,dims=[1,2]) + triview_color4 = torch.rot90(triview_color[1],k=2,dims=[1,2]).flip(1).flip(1).flip(2) + triview_color5 = torch.rot90(triview_color[3],k=1,dims=[1,2]).flip(1).flip(2) + if fix == True: + triview_color0[1] = triview_color0[1] * 0 + triview_color0[2] = triview_color0[2] * 0 + triview_color3[1] = triview_color3[1] * 0 + triview_color3[2] = triview_color3[2] * 0 + + triview_color1[0] = triview_color1[0] * 0 + triview_color1[1] = triview_color1[1] * 0 + triview_color4[0] = triview_color4[0] * 0 + triview_color4[1] = triview_color4[1] * 0 + + triview_color2[0] = triview_color2[0] * 0 + triview_color2[2] = triview_color2[2] * 0 + triview_color5[0] = triview_color5[0] * 0 + triview_color5[2] = triview_color5[2] * 0 + color_tensor1_gt = torch.cat((triview_color0, triview_color1, triview_color2), dim=2) + color_tensor2_gt = torch.cat((triview_color3, triview_color4, triview_color5), dim=2) + color_tensor_gt = torch.cat((color_tensor1_gt, color_tensor2_gt), dim = dim) + return color_tensor_gt +
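Note (not part of the patch): a minimal usage sketch for the camera helpers added in util/utils.py, composing perspective(), translate(), rotate_x() and rotate_y() into a model-view-projection matrix. The camera distance and rotation angles below are illustrative assumptions, and the sketch assumes util/utils.py is importable as util.utils.

import numpy as np
import torch
from util.utils import perspective, translate, rotate_x, rotate_y

device = "cuda" if torch.cuda.is_available() else "cpu"
# 4x4 projection built from the horizontal field of view (fovx is in radians)
proj = perspective(fovx=0.7854, aspect=1.0, n=0.1, f=1000.0, device=device)
# 4x4 model-view: pull the camera back along -z, then tilt and orbit (illustrative values)
mv = translate(0.0, 0.0, -2.5, device=device) @ rotate_x(np.pi / 6, device=device) @ rotate_y(np.pi / 4, device=device)
mvp = proj @ mv
# map one homogeneous object-space point to clip space
point = torch.tensor([0.0, 0.0, 0.0, 1.0], device=device)
clip = mvp @ point

Note that, despite the "using fovy" wording in its comment, perspective() as added here takes a horizontal field of view (fovx); the commented-out fovy line inside the function suggests the vertical-FOV variant was dropped.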
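Note (not part of the patch): per its inline comment, get_tri() expects a [6, C, H, W] stack of views and tiles them into a two-row, three-column atlas (rows concatenated along H, columns along W). A quick shape sanity check, assuming square views; the tensor below is placeholder data only.

import torch
from util.utils import get_tri

triview = torch.rand(6, 3, 256, 256)           # six views, RGB, 256x256 (placeholder data)
atlas = get_tri(triview, dim=1, blender=True)  # default scale=0.95 shrinks content but keeps H, W
print(atlas.shape)                             # torch.Size([3, 512, 768]) -> 2 rows x 3 columns of 256x256 tiles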