diff --git a/.ipynb_checkpoints/README-checkpoint.md b/.ipynb_checkpoints/README-checkpoint.md
new file mode 100644
index 0000000000000000000000000000000000000000..2a44cdb976e26c64ec481be085401da0d7508ffe
--- /dev/null
+++ b/.ipynb_checkpoints/README-checkpoint.md
@@ -0,0 +1,68 @@
+**Chinese version: [中文](README_zh.md)**
+
+# Unique3D
+High-Quality and Efficient 3D Mesh Generation from a Single Image
+
+## [Paper]() | [Project page](https://wukailu.github.io/Unique3D/) | [Huggingface Demo]() | [Online Demo](https://www.aiuni.ai/)
+
+![](assets/fig_teaser.png)
+
+High-fidelity and diverse textured meshes generated by Unique3D from single-view, in-the-wild images in about 30 seconds.
+
+## More features
+
+The repo is still under construction; thanks for your patience.
+- [x] Local gradio demo.
+- [ ] Detailed tutorial.
+- [ ] Huggingface demo.
+- [ ] Detailed local demo.
+- [ ] Comfyui support.
+- [ ] Windows support.
+- [ ] Docker support.
+- [ ] More stable reconstruction with normal.
+- [ ] Training code release.
+
+## Preparation for inference
+
+### Linux System Setup.
+```bash
+conda create -n unique3d
+conda activate unique3d
+pip install -r requirements.txt
+```
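+
+The demo loads its pipelines on `cuda`, so it is worth confirming that PyTorch can see a GPU before launching anything. A minimal sanity check:
+```bash
+python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+```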
+
+### Interactive inference: run your local gradio demo.
+
+1. Download [ckpt.zip]() and extract it into `ckpt/`, so that the directory layout matches:
+```
+Unique3D
+└── ckpt
+    ├── controlnet-tile/
+    ├── image2normal/
+    ├── img2mvimg/
+    ├── realesrgan-x4.onnx
+    └── v1-inference.yaml
+```
+
+2. Run the interactive inference locally.
+```bash
+python app/gradio_local.py --port 7860
+```
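+
+The launcher also accepts a few optional flags (defined in `app/gradio_local.py`): `--listen` binds to `0.0.0.0` so other machines can reach the demo, `--share` creates a public Gradio link, and `--gradio_root` serves the app under a sub-path. For example:
+```bash
+python app/gradio_local.py --port 7860 --listen
+```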
+
+## Tips to get better results
+
+1. Unique3D is sensitive to the facing direction of the input image. Due to the distribution of the training data, orthographic, front-facing images of subjects in a rest pose generally lead to good reconstructions.
+2. Images with occlusions will cause worse reconstructions, since four views cannot cover the complete object. Images with fewer occlusions lead to better results.
+3. Use an input image with as high a resolution as possible when fine detail matters.
+
+## Acknowledgement
+
+We have borrowed code extensively from the following repositories. Many thanks to the authors for sharing their code.
+- [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
+- [Wonder3d](https://github.com/xxlong0/Wonder3D)
+- [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus)
+- [Continuous Remeshing](https://github.com/Profactor/continuous-remeshing)
+- [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals)
+
+## Collaborations
+Our mission is to create a 4D generative model with 3D concepts. This is just our first step, and the road ahead is still long, but we are confident. We warmly invite you to join the discussion and explore potential collaborations in any capacity. **If you're interested in connecting or partnering with us, please don't hesitate to reach out via email (wkl22@mails.tsinghua.edu.cn)**.
\ No newline at end of file
diff --git a/README_zh.md b/README_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..22e82c80833d28b01024aad82653a398e6011e1a
--- /dev/null
+++ b/README_zh.md
@@ -0,0 +1,56 @@
+**Other languages: [English](README.md)**
+
+# Unique3D
+High-Quality and Efficient 3D Mesh Generation from a Single Image
+
+## [Paper]() | [Project page](https://wukailu.github.io/Unique3D/) | [Huggingface Demo]() | [Online Demo](https://www.aiuni.ai/)
+
+![](assets/fig_teaser.png)
+
+Unique3D generates high-fidelity and diverse textured meshes from a single-view image in about 30 seconds on an RTX 4090.
+
+### Preparation for inference
+
+#### Linux system setup
+```bash
+conda create -n unique3d
+conda activate unique3d
+pip install -r requirements.txt
+```
+
+#### Interactive inference: run your local gradio demo
+
+1. Download [ckpt.zip]() and extract it into `ckpt/`, so that the directory layout matches:
+```
+Unique3D
+└── ckpt
+    ├── controlnet-tile/
+    ├── image2normal/
+    ├── img2mvimg/
+    ├── realesrgan-x4.onnx
+    └── v1-inference.yaml
+```
+
+2. Run the interactive inference locally.
+```bash
+python app/gradio_local.py --port 7860
+```
+
+## Tips to get better results
+
+1. Unique3D is sensitive to the facing direction of the input image. Due to the distribution of the training data, **orthographic, front-facing images** generally lead to good reconstructions. For characters, an A-pose or T-pose works best, since the training data currently contains few other poses.
+2. Images with occlusions lead to worse reconstructions, since the four views cannot cover the complete object. Images with fewer occlusions give better results.
+3. Use an input image with as high a resolution as possible.
+
+## Acknowledgement
+
+We have borrowed code from the following repositories. Many thanks to the authors for sharing their code.
+- [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
+- [Wonder3d](https://github.com/xxlong0/Wonder3D)
+- [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus)
+- [Continuous Remeshing](https://github.com/Profactor/continuous-remeshing)
+- [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals)
+
+## Collaborations
+
+Our mission is to create a 4D generative model with 3D concepts. This is just our first step, and the road ahead is still long, but we are confident. We warmly invite you to join the discussion and explore potential collaborations in any capacity. **If you're interested in connecting or partnering with us, please don't hesitate to reach out via email (wkl22@mails.tsinghua.edu.cn)**.
diff --git a/app/__init__.py b/app/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/app/all_models.py b/app/all_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7df963350c704c864b73d1e403f880f86bdfd5d
--- /dev/null
+++ b/app/all_models.py
@@ -0,0 +1,22 @@
+import torch
+from scripts.sd_model_zoo import load_common_sd15_pipe
+from diffusers import StableDiffusionControlNetImg2ImgPipeline
+
+
+class MyModelZoo:
+ _pipe_disney_controlnet_lineart_ipadapter_i2i: StableDiffusionControlNetImg2ImgPipeline = None
+
+ base_model = "runwayml/stable-diffusion-v1-5"
+
+ def __init__(self, base_model=None) -> None:
+ if base_model is not None:
+ self.base_model = base_model
+
+ @property
+ def pipe_disney_controlnet_tile_ipadapter_i2i(self):
+ return self._pipe_disney_controlnet_lineart_ipadapter_i2i
+
+ def init_models(self):
+ self._pipe_disney_controlnet_lineart_ipadapter_i2i = load_common_sd15_pipe(base_model=self.base_model, ip_adapter=True, plus_model=False, controlnet="./ckpt/controlnet-tile", pipeline_class=StableDiffusionControlNetImg2ImgPipeline)
+
+model_zoo = MyModelZoo()
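+
+
+# ---------------------------------------------------------------------------
+# Minimal usage sketch (illustrative, not part of the original module). The demo
+# calls init_models() once at startup and then reuses the cached ControlNet-tile
+# img2img pipeline. The call below assumes the standard diffusers
+# StableDiffusionControlNetImg2ImgPipeline interface; exact behaviour depends on
+# how load_common_sd15_pipe configures the pipeline, and the image path is only
+# an example.
+# ---------------------------------------------------------------------------
+if __name__ == "__main__":
+    from PIL import Image
+
+    model_zoo.init_models()
+    pipe = model_zoo.pipe_disney_controlnet_tile_ipadapter_i2i
+
+    img = Image.open("app/examples/Groot.png").convert("RGB").resize((512, 512))
+    refined = pipe(
+        prompt="high quality",
+        image=img,               # init image for img2img
+        control_image=img,       # condition for the tile ControlNet
+        ip_adapter_image=img,    # reference image for the loaded IP-Adapter
+        strength=0.3,
+        guidance_scale=2.0,
+    ).images[0]
+    refined.save("refined.png")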
diff --git a/app/custom_models/image2image-objaverseF-rgb2normal.yaml b/app/custom_models/image2image-objaverseF-rgb2normal.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..de5e9e871d77eaa20ffe0d1488f91c82e39a8542
--- /dev/null
+++ b/app/custom_models/image2image-objaverseF-rgb2normal.yaml
@@ -0,0 +1,61 @@
+pretrained_model_name_or_path: "lambdalabs/sd-image-variations-diffusers"
+mixed_precision: "bf16"
+
+init_config:
+ # enable controls
+ enable_cross_attn_lora: False
+ enable_cross_attn_ip: False
+ enable_self_attn_lora: False
+ enable_self_attn_ref: True
+ enable_multiview_attn: False
+
+ # for cross attention
+ init_cross_attn_lora: False
+ init_cross_attn_ip: False
+ cross_attn_lora_rank: 512 # 0 for not enabled
+ cross_attn_lora_only_kv: False
+ ipadapter_pretrained_name: "h94/IP-Adapter"
+ ipadapter_subfolder_name: "models"
+ ipadapter_weight_name: "ip-adapter_sd15.safetensors"
+ ipadapter_effect_on: "all" # all, first
+
+ # for self attention
+ init_self_attn_lora: False
+ self_attn_lora_rank: 512
+ self_attn_lora_only_kv: False
+
+ # for self attention ref
+ init_self_attn_ref: True
+ self_attn_ref_position: "attn1"
+ self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
+ self_attn_ref_pixel_wise_crosspond: True
+ self_attn_ref_effect_on: "all"
+
+ # for multiview attention
+ init_multiview_attn: False
+ multiview_attn_position: "attn1"
+ num_modalities: 1
+
+ # for unet
+ init_unet_path: "${pretrained_model_name_or_path}"
+ init_num_cls_label: 0 # for initialize
+ cls_labels: [] # for current task
+
+trainers:
+ - trainer_type: "image2image_trainer"
+ trainer:
+ pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
+ attn_config:
+ cls_labels: [] # for current task
+ enable_cross_attn_lora: False
+ enable_cross_attn_ip: False
+ enable_self_attn_lora: False
+ enable_self_attn_ref: True
+ enable_multiview_attn: False
+ resolution: "512"
+ condition_image_resolution: "512"
+ condition_image_column_name: "conditioning_image"
+ image_column_name: "image"
+
+
+
diff --git a/app/custom_models/image2mvimage-objaverseFrot-wonder3d.yaml b/app/custom_models/image2mvimage-objaverseFrot-wonder3d.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..07ad06caeae33c598a929d8f3c4595bd403da32d
--- /dev/null
+++ b/app/custom_models/image2mvimage-objaverseFrot-wonder3d.yaml
@@ -0,0 +1,63 @@
+pretrained_model_name_or_path: "./ckpt/img2mvimg"
+mixed_precision: "bf16"
+
+init_config:
+ # enable controls
+ enable_cross_attn_lora: False
+ enable_cross_attn_ip: False
+ enable_self_attn_lora: False
+ enable_self_attn_ref: False
+ enable_multiview_attn: True
+
+ # for cross attention
+ init_cross_attn_lora: False
+ init_cross_attn_ip: False
+ cross_attn_lora_rank: 256 # 0 for not enabled
+ cross_attn_lora_only_kv: False
+ ipadapter_pretrained_name: "h94/IP-Adapter"
+ ipadapter_subfolder_name: "models"
+ ipadapter_weight_name: "ip-adapter_sd15.safetensors"
+ ipadapter_effect_on: "all" # all, first
+
+ # for self attention
+ init_self_attn_lora: False
+ self_attn_lora_rank: 256
+ self_attn_lora_only_kv: False
+
+ # for self attention ref
+ init_self_attn_ref: False
+ self_attn_ref_position: "attn1"
+ self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
+ self_attn_ref_pixel_wise_crosspond: False
+ self_attn_ref_effect_on: "all"
+
+ # for multiview attention
+ init_multiview_attn: True
+ multiview_attn_position: "attn1"
+ use_mv_joint_attn: True
+ num_modalities: 1
+
+ # for unet
+ init_unet_path: "${pretrained_model_name_or_path}"
+ cat_condition: True # cat condition to input
+
+ # for cls embedding
+ init_num_cls_label: 8 # for initialize
+ cls_labels: [0, 1, 2, 3] # for current task
+
+trainers:
+ - trainer_type: "image2mvimage_trainer"
+ trainer:
+ pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
+ attn_config:
+ cls_labels: [0, 1, 2, 3] # for current task
+ enable_cross_attn_lora: False
+ enable_cross_attn_ip: False
+ enable_self_attn_lora: False
+ enable_self_attn_ref: False
+ enable_multiview_attn: True
+ resolution: "256"
+ condition_image_resolution: "256"
+ normal_cls_offset: 4
+ condition_image_column_name: "conditioning_image"
+ image_column_name: "image"
\ No newline at end of file
diff --git a/app/custom_models/mvimg_prediction.py b/app/custom_models/mvimg_prediction.py
new file mode 100644
index 0000000000000000000000000000000000000000..56aa10927481cfa1201538ba349d2202c41aeddf
--- /dev/null
+++ b/app/custom_models/mvimg_prediction.py
@@ -0,0 +1,57 @@
+import sys
+import torch
+import gradio as gr
+from PIL import Image
+import numpy as np
+from rembg import remove
+from app.utils import change_rgba_bg, rgba_to_rgb
+from app.custom_models.utils import load_pipeline
+from scripts.all_typing import *
+from scripts.utils import session, simple_preprocess
+
+training_config = "app/custom_models/image2mvimage-objaverseFrot-wonder3d.yaml"
+checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth"
+trainer, pipeline = load_pipeline(training_config, checkpoint_path)
+pipeline.enable_model_cpu_offload()
+
+def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
+ if isinstance(img_list, Image.Image):
+ img_list = [img_list]
+ img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
+ ret = []
+ for img in img_list:
+ images = trainer.pipeline_forward(
+ pipeline=pipeline,
+ image=img,
+ guidance_scale=guidance_scale,
+ **kwargs
+ ).images
+ ret.extend(images)
+ return ret
+
+
+def run_mvprediction(input_image: Image.Image, remove_bg=True, guidance_scale=1.5, seed=1145):
+ if input_image.mode == 'RGB' or np.array(input_image)[..., -1].mean() == 255.:
+ # still do remove using rembg, since simple_preprocess requires RGBA image
+ print("RGB image not RGBA! still remove bg!")
+ remove_bg = True
+
+ if remove_bg:
+ input_image = remove(input_image, session=session)
+
+ # make front_pil RGBA with white bg
+ input_image = change_rgba_bg(input_image, "white")
+ single_image = simple_preprocess(input_image)
+
+ generator = torch.Generator(device="cuda").manual_seed(int(seed)) if seed >= 0 else None
+
+ rgb_pils = predict(
+ single_image,
+ generator=generator,
+ guidance_scale=guidance_scale,
+ width=256,
+ height=256,
+ num_inference_steps=30,
+ )
+
+ return rgb_pils, single_image
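+
+
+# Minimal usage sketch (illustrative, not part of the original module): produce
+# the multi-view RGB images (four views in the default configuration) from a
+# single front-view photo. The example path below is only an assumption.
+if __name__ == "__main__":
+    img = Image.open("app/examples/Groot.png")
+    views, front = run_mvprediction(img, remove_bg=True, guidance_scale=1.5, seed=1145)
+    front.save("front.png")
+    for i, view in enumerate(views):
+        view.save(f"mv_{i}.png")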
diff --git a/app/custom_models/normal_prediction.py b/app/custom_models/normal_prediction.py
new file mode 100644
index 0000000000000000000000000000000000000000..302357e35c6b303429bdbbd7dd57cd8cd879770a
--- /dev/null
+++ b/app/custom_models/normal_prediction.py
@@ -0,0 +1,26 @@
+import sys
+from PIL import Image
+from app.utils import rgba_to_rgb, simple_remove
+from app.custom_models.utils import load_pipeline
+from scripts.utils import rotate_normals_torch
+from scripts.all_typing import *
+
+training_config = "app/custom_models/image2image-objaverseF-rgb2normal.yaml"
+checkpoint_path = "ckpt/image2normal/unet_state_dict.pth"
+trainer, pipeline = load_pipeline(training_config, checkpoint_path)
+pipeline.enable_model_cpu_offload()
+
+def predict_normals(image: List[Image.Image], guidance_scale=2., do_rotate=True, num_inference_steps=30, **kwargs):
+ img_list = image if isinstance(image, list) else [image]
+ img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
+ images = trainer.pipeline_forward(
+ pipeline=pipeline,
+ image=img_list,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ **kwargs
+ ).images
+ images = simple_remove(images)
+ if do_rotate and len(images) > 1:
+ images = rotate_normals_torch(images, return_types='pil')
+ return images
\ No newline at end of file
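+
+# Minimal usage sketch (illustrative, not part of the original module): predict a
+# normal map for each RGB view, e.g. the outputs of run_mvprediction. The file
+# names below are only assumptions.
+if __name__ == "__main__":
+    views = [Image.open(f"mv_{i}.png") for i in range(4)]
+    normals = predict_normals(views, guidance_scale=2., do_rotate=True)
+    for i, normal in enumerate(normals):
+        normal.save(f"normal_{i}.png")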
diff --git a/app/custom_models/utils.py b/app/custom_models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dff2824ec7fe5926fa4dfc0f8b73fab8abf995e
--- /dev/null
+++ b/app/custom_models/utils.py
@@ -0,0 +1,75 @@
+import torch
+from typing import List
+from dataclasses import dataclass
+from app.utils import rgba_to_rgb
+from custum_3d_diffusion.trainings.config_classes import ExprimentConfig, TrainerSubConfig
+from custum_3d_diffusion import modules
+from custum_3d_diffusion.custum_modules.unifield_processor import AttnConfig, ConfigurableUNet2DConditionModel
+from custum_3d_diffusion.trainings.base import BasicTrainer
+from custum_3d_diffusion.trainings.utils import load_config
+
+
+@dataclass
+class FakeAccelerator:
+ device: torch.device = torch.device("cuda")
+
+
+def init_trainers(cfg_path: str, weight_dtype: torch.dtype, extras: dict):
+ accelerator = FakeAccelerator()
+ cfg: ExprimentConfig = load_config(ExprimentConfig, cfg_path, extras)
+ init_config: AttnConfig = load_config(AttnConfig, cfg.init_config)
+ configurable_unet = ConfigurableUNet2DConditionModel(init_config, weight_dtype)
+ configurable_unet.enable_xformers_memory_efficient_attention()
+ trainer_cfgs: List[TrainerSubConfig] = [load_config(TrainerSubConfig, trainer) for trainer in cfg.trainers]
+ trainers: List[BasicTrainer] = [modules.find(trainer.trainer_type)(accelerator, None, configurable_unet, trainer.trainer, weight_dtype, i) for i, trainer in enumerate(trainer_cfgs)]
+ return trainers, configurable_unet
+
+from app.utils import make_image_grid, split_image
+def process_image(function, img, guidance_scale=2., merged_image=False, remove_bg=True):
+ from rembg import remove
+ if remove_bg:
+ img = remove(img)
+ img = rgba_to_rgb(img)
+ if merged_image:
+ img = split_image(img, rows=2)
+ images = function(
+ image=img,
+ guidance_scale=guidance_scale,
+ )
+ if len(images) > 1:
+ return make_image_grid(images, rows=2)
+ else:
+ return images[0]
+
+
+def process_text(trainer, pipeline, img, guidance_scale=2.):
+ pipeline.cfg.validation_prompts = [img]
+ titles, images = trainer.batched_validation_forward(pipeline, guidance_scale=[guidance_scale])
+ return images[0]
+
+
+def load_pipeline(config_path, ckpt_path, pipeline_filter=lambda x: True, weight_dtype = torch.bfloat16):
+ training_config = config_path
+ load_from_checkpoint = ckpt_path
+ extras = []
+ device = "cuda"
+ trainers, configurable_unet = init_trainers(training_config, weight_dtype, extras)
+ shared_modules = dict()
+ for trainer in trainers:
+ shared_modules = trainer.init_shared_modules(shared_modules)
+
+ if load_from_checkpoint is not None:
+ state_dict = torch.load(load_from_checkpoint)
+ configurable_unet.unet.load_state_dict(state_dict, strict=False)
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ configurable_unet.unet.to(device, dtype=weight_dtype)
+
+ pipeline = None
+ trainer_out = None
+ for trainer in trainers:
+ if pipeline_filter(trainer.cfg.trainer_name):
+ pipeline = trainer.construct_pipeline(shared_modules, configurable_unet.unet)
+ pipeline.set_progress_bar_config(disable=False)
+ trainer_out = trainer
+ pipeline = pipeline.to(device)
+ return trainer_out, pipeline
\ No newline at end of file
diff --git a/app/examples/Groot.png b/app/examples/Groot.png
new file mode 100644
index 0000000000000000000000000000000000000000..2aaf0fb631fe91c062758d8c801fcdbbf44c7e79
--- /dev/null
+++ b/app/examples/Groot.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9096d048ec8deb3673765c577c7030118a75fc87d3da08cec657f66dfd22479
+size 777998
diff --git a/app/examples/aaa.png b/app/examples/aaa.png
new file mode 100644
index 0000000000000000000000000000000000000000..d449a0ade138a7e09b733baa560e7e3f76d713a8
--- /dev/null
+++ b/app/examples/aaa.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0733f0c5ed507e3fc0a9f921c1b078e7a66526335ee8efee61e919233a05a1c1
+size 903027
diff --git a/app/examples/abma.png b/app/examples/abma.png
new file mode 100644
index 0000000000000000000000000000000000000000..a62eee4708ab26b24323ebef9825b9b28ec85918
--- /dev/null
+++ b/app/examples/abma.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24640851ccf40f2e61313c81e702abffe2361f1c5a1ab6e5b46f328daba103b3
+size 93457
diff --git a/app/examples/akun.png b/app/examples/akun.png
new file mode 100644
index 0000000000000000000000000000000000000000..0d09d281c775aceb8fe1a9e809c38c2aeec8106d
--- /dev/null
+++ b/app/examples/akun.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b60404d448f09a3c11147f5d9e0e0544f0c2d4473425f110ded783cebf9c1f76
+size 181112
diff --git a/app/examples/anya.png b/app/examples/anya.png
new file mode 100644
index 0000000000000000000000000000000000000000..78c2802edcc25dc2e0be58d419f642a2323ffc28
--- /dev/null
+++ b/app/examples/anya.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb2ae59e3bb9c028f12c6c587cae7219c389df4593379c74211a6c643cf0ffa7
+size 611788
diff --git a/app/examples/bag.png b/app/examples/bag.png
new file mode 100644
index 0000000000000000000000000000000000000000..e91e10cc662404fdfddc5d8b5df7f68c7028e31c
--- /dev/null
+++ b/app/examples/bag.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac798ea1f112091c04f5bdfa47c490806fb433a02fe17758aa1f8c55cd64b66e
+size 1544762
diff --git a/app/examples/generated_1715761545_frame0.png b/app/examples/generated_1715761545_frame0.png
new file mode 100644
index 0000000000000000000000000000000000000000..a74121776947374eb8b3c0ff95afbe92472d2e3c
--- /dev/null
+++ b/app/examples/generated_1715761545_frame0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff813fe203a97a916bc73fa2bb61229c6c81884484cee1da53ff131093780636
+size 207964
diff --git a/app/examples/generated_1715762357_frame0.png b/app/examples/generated_1715762357_frame0.png
new file mode 100644
index 0000000000000000000000000000000000000000..061506d4a4c24f9021bbb9a4941c1347087799a5
--- /dev/null
+++ b/app/examples/generated_1715762357_frame0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f211e298d5e6ffc2fc7d8ad5133e81b471d13ab6398931e8386ea9698021b4b
+size 234892
diff --git a/app/examples/generated_1715763329_frame0.png b/app/examples/generated_1715763329_frame0.png
new file mode 100644
index 0000000000000000000000000000000000000000..5fa5682f4b3c476424cb331df7e5ffbed76be9fc
--- /dev/null
+++ b/app/examples/generated_1715763329_frame0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e86aee7707d9870e1f56a24be9c52ff42048d4f45ed39d52e86f293336189580
+size 181983
diff --git a/app/examples/hatsune_miku.png b/app/examples/hatsune_miku.png
new file mode 100644
index 0000000000000000000000000000000000000000..6b8793cfb5b0f884f11cf720b343fb6ef0a7d77f
--- /dev/null
+++ b/app/examples/hatsune_miku.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbb6285c5a9a670bdee0992c6db2e43b51c584f3adb052d89136000b52eedc97
+size 96183
diff --git a/app/examples/princess-large.png b/app/examples/princess-large.png
new file mode 100644
index 0000000000000000000000000000000000000000..58b8b39b4ea45cf289e597c257f0f58fc373f0da
--- /dev/null
+++ b/app/examples/princess-large.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:203fd1fef34720656e51d27b1bfdc8c0a082a9fbbf48f3100039a63dcc59fd57
+size 65470
diff --git a/app/examples/shoe.png b/app/examples/shoe.png
new file mode 100644
index 0000000000000000000000000000000000000000..4ffa98e323681d5b9569dd90ddfd35d3f647cad8
--- /dev/null
+++ b/app/examples/shoe.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b3798b58377246626b0ff7d38fd0a5ff028399b3e5b9b53b92785707a3ca081
+size 248618
diff --git a/app/gradio_3dgen.py b/app/gradio_3dgen.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2113a773e36884e185550fa036f5ba4c72611c3
--- /dev/null
+++ b/app/gradio_3dgen.py
@@ -0,0 +1,71 @@
+import os
+import gradio as gr
+from PIL import Image
+from pytorch3d.structures import Meshes
+from app.utils import clean_up
+from app.custom_models.mvimg_prediction import run_mvprediction
+from app.custom_models.normal_prediction import predict_normals
+from scripts.refine_lr_to_sr import run_sr_fast
+from scripts.utils import save_glb_and_video
+from scripts.multiview_inference import geo_reconstruct
+
+def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
+ if preview_img is None:
+ raise gr.Error("preview_img is none")
+ if isinstance(preview_img, str):
+ preview_img = Image.open(preview_img)
+
+ if preview_img.size[0] <= 512:
+ preview_img = run_sr_fast([preview_img])[0]
+ rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=int(seed)) # 6s
+ new_meshes = geo_reconstruct(rgb_pils, None, front_pil, do_refine=do_refine, predict_normal=True, expansion_weight=expansion_weight, init_type=init_type)
+ vertices = new_meshes.verts_packed()
+ vertices = vertices / 2 * 1.35
+ vertices[..., [0, 2]] = - vertices[..., [0, 2]]
+ new_meshes = Meshes(verts=[vertices], faces=new_meshes.faces_list(), textures=new_meshes.textures)
+
+ ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=render_video)
+ return ret_mesh, video
+
+#######################################
+def create_ui(concurrency_id="wkl"):
+ with gr.Row():
+ with gr.Column(scale=2):
+ input_image = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
+
+ example_folder = os.path.join(os.path.dirname(__file__), "./examples")
+ example_fns = sorted([os.path.join(example_folder, example) for example in os.listdir(example_folder)])
+ gr.Examples(
+ examples=example_fns,
+ inputs=[input_image],
+ cache_examples=False,
+ label='Examples (click one of the images below to start)',
+ examples_per_page=12
+ )
+
+
+ with gr.Column(scale=3):
+ # export mesh display
+ output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True, height=320)
+ output_video = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False)
+
+ input_processing = gr.Checkbox(
+ value=True,
+ label='Remove Background',
+ visible=True,
+ )
+ do_refine = gr.Checkbox(value=True, label="Refine Multiview Details", visible=False)
+ expansion_weight = gr.Slider(minimum=-1., maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
+ init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False)
+ setable_seed = gr.Slider(-1, 1000000000, -1, step=1, visible=True, label="Seed")
+ render_video = gr.Checkbox(value=False, visible=False, label="generate video")
+ fullrunv2_btn = gr.Button('Generate 3D', interactive=True)
+
+ fullrunv2_btn.click(
+ fn = generate3dv2,
+ inputs=[input_image, input_processing, setable_seed, render_video, do_refine, expansion_weight, init_type],
+ outputs=[output_mesh, output_video],
+ concurrency_id=concurrency_id,
+ api_name="generate3dv2",
+ ).success(clean_up, api_name=False)
+ return input_image
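+
+
+# Headless usage sketch (illustrative, not part of the original module): run the
+# same image-to-mesh path that the "Generate 3D" button triggers, without the UI.
+# model_zoo.init_models() mirrors what app/gradio_local.py does at startup; the
+# example image path is only an assumption. Outputs are written under
+# /tmp/gradio/generated by save_glb_and_video.
+if __name__ == "__main__":
+    from app.all_models import model_zoo
+    model_zoo.init_models()
+    mesh_path, video_path = generate3dv2(
+        Image.open("app/examples/Groot.png"),
+        input_processing=True,
+        seed=-1,
+        render_video=False,
+    )
+    print("mesh written to:", mesh_path)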
diff --git a/app/gradio_3dgen_steps.py b/app/gradio_3dgen_steps.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dc91ada685a3b8b789770323b2bdfd7ca4fd3ac
--- /dev/null
+++ b/app/gradio_3dgen_steps.py
@@ -0,0 +1,87 @@
+import gradio as gr
+from PIL import Image
+
+from app.custom_models.mvimg_prediction import run_mvprediction
+from app.utils import make_image_grid, split_image
+from scripts.utils import save_glb_and_video
+
+def concept_to_multiview(preview_img, input_processing, seed, guidance=1.):
+ seed = int(seed)
+ if preview_img is None:
+ raise gr.Error("preview_img is none.")
+ if isinstance(preview_img, str):
+ preview_img = Image.open(preview_img)
+
+ rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=seed, guidance_scale=guidance)
+ rgb_pil = make_image_grid(rgb_pils, rows=2)
+ return rgb_pil, front_pil
+
+def concept_to_multiview_ui(concurrency_id="wkl"):
+ with gr.Row():
+ with gr.Column(scale=2):
+ preview_img = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
+ input_processing = gr.Checkbox(
+ value=True,
+ label='Remove Background',
+ )
+ seed = gr.Slider(minimum=-1, maximum=1000000000, value=-1, step=1.0, label="seed")
+ guidance = gr.Slider(minimum=1.0, maximum=5.0, value=1.0, label="Guidance Scale", step=0.5)
+ run_btn = gr.Button('Generate Multiview', interactive=True)
+ with gr.Column(scale=3):
+ # export mesh display
+ output_rgb = gr.Image(type='pil', label="RGB", show_label=True)
+ output_front = gr.Image(type='pil', image_mode='RGBA', label="Frontview", show_label=True)
+ run_btn.click(
+ fn = concept_to_multiview,
+ inputs=[preview_img, input_processing, seed, guidance],
+ outputs=[output_rgb, output_front],
+ concurrency_id=concurrency_id,
+ api_name=False,
+ )
+ return output_rgb, output_front
+
+from app.custom_models.normal_prediction import predict_normals
+from scripts.multiview_inference import geo_reconstruct
+def multiview_to_mesh_v2(rgb_pil, normal_pil, front_pil, do_refine=False, expansion_weight=0.1, init_type="std"):
+ rgb_pils = split_image(rgb_pil, rows=2)
+ if normal_pil is not None:
+ normal_pil = split_image(normal_pil, rows=2)
+ if front_pil is None:
+ front_pil = rgb_pils[0]
+ new_meshes = geo_reconstruct(rgb_pils, normal_pil, front_pil, do_refine=do_refine, predict_normal=normal_pil is None, expansion_weight=expansion_weight, init_type=init_type)
+ ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=False)
+ return ret_mesh
+
+def new_multiview_to_mesh_ui(concurrency_id="wkl"):
+ with gr.Row():
+ with gr.Column(scale=2):
+ rgb_pil = gr.Image(type='pil', image_mode='RGB', label='RGB')
+            front_pil = gr.Image(type='pil', image_mode='RGBA', label='Frontview (Optional)')
+            normal_pil = gr.Image(type='pil', image_mode='RGBA', label='Normal (Optional)')
+ do_refine = gr.Checkbox(
+ value=False,
+ label='Refine rgb',
+ visible=False,
+ )
+ expansion_weight = gr.Slider(minimum=-1.0, maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
+ init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh initialization", value="std", visible=False)
+ run_btn = gr.Button('Generate 3D', interactive=True)
+ with gr.Column(scale=3):
+ # export mesh display
+ output_mesh = gr.Model3D(value=None, label="mesh model", show_label=True)
+ run_btn.click(
+ fn = multiview_to_mesh_v2,
+ inputs=[rgb_pil, normal_pil, front_pil, do_refine, expansion_weight, init_type],
+ outputs=[output_mesh],
+ concurrency_id=concurrency_id,
+ api_name="multiview_to_mesh",
+ )
+ return rgb_pil, front_pil, output_mesh
+
+
+#######################################
+def create_step_ui(concurrency_id="wkl"):
+ with gr.Tab(label="3D:concept_to_multiview"):
+ concept_to_multiview_ui(concurrency_id)
+ with gr.Tab(label="3D:new_multiview_to_mesh"):
+ new_multiview_to_mesh_ui(concurrency_id)
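+
+
+# Step-by-step usage sketch (illustrative, not part of the original module):
+# chain the two stages manually instead of through the Gradio tabs. As in
+# app/gradio_local.py, model_zoo.init_models() is called once up front; the
+# example image path is only an assumption.
+if __name__ == "__main__":
+    from app.all_models import model_zoo
+    model_zoo.init_models()
+    img = Image.open("app/examples/Groot.png")
+    rgb_grid, front = concept_to_multiview(img, input_processing=True, seed=-1, guidance=1.5)
+    mesh_path = multiview_to_mesh_v2(rgb_grid, None, front)
+    print("mesh written to:", mesh_path)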
diff --git a/app/gradio_local.py b/app/gradio_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e6317b1c5c4f5099f4969e9cbec6b3736c2beee
--- /dev/null
+++ b/app/gradio_local.py
@@ -0,0 +1,76 @@
+if __name__ == "__main__":
+ import os
+ import sys
+ sys.path.append(os.curdir)
+ if 'CUDA_VISIBLE_DEVICES' not in os.environ:
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+ os.environ['TRANSFORMERS_OFFLINE']='0'
+ os.environ['DIFFUSERS_OFFLINE']='0'
+ os.environ['HF_HUB_OFFLINE']='0'
+ os.environ['GRADIO_ANALYTICS_ENABLED']='False'
+ os.environ['HF_ENDPOINT']='https://hf-mirror.com'
+ import torch
+ torch.set_float32_matmul_precision('medium')
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.set_grad_enabled(False)
+
+import gradio as gr
+import argparse
+
+from app.gradio_3dgen import create_ui as create_3d_ui
+# from app.gradio_3dgen_steps import create_step_ui
+from app.all_models import model_zoo
+
+
+_TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image'''
+_DESCRIPTION = '''
+[Project page](https://wukailu.github.io/Unique3D/)
+
+* High-fidelity and diverse textured meshes generated by Unique3D from single-view images.
+
+* The demo is still under construction, and more features are expected to be implemented soon.
+'''
+
+def launch(
+ port,
+ listen=False,
+ share=False,
+ gradio_root="",
+):
+ model_zoo.init_models()
+
+ with gr.Blocks(
+ title=_TITLE,
+ theme=gr.themes.Monochrome(),
+ ) as demo:
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown('# ' + _TITLE)
+ gr.Markdown(_DESCRIPTION)
+ create_3d_ui("wkl")
+
+ launch_args = {}
+ if listen:
+ launch_args["server_name"] = "0.0.0.0"
+
+ demo.queue(default_concurrency_limit=1).launch(
+ server_port=None if port == 0 else port,
+ share=share,
+ root_path=gradio_root if gradio_root != "" else None, # "/myapp"
+ **launch_args,
+ )
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+ parser.add_argument("--listen", action="store_true")
+ parser.add_argument("--port", type=int, default=0)
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument("--gradio_root", default="")
+ args = parser.parse_args()
+ launch(
+ args.port,
+ listen=args.listen,
+ share=args.share,
+ gradio_root=args.gradio_root,
+ )
\ No newline at end of file
diff --git a/app/utils.py b/app/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5385bfb256015dd892682ca6b8936c6632a453bf
--- /dev/null
+++ b/app/utils.py
@@ -0,0 +1,112 @@
+import torch
+import numpy as np
+from PIL import Image
+import gc
+from scripts.refine_lr_to_sr import run_sr_fast
+
+GRADIO_CACHE = "/tmp/gradio/"
+
+def clean_up():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+def remove_color(arr):
+ if arr.shape[-1] == 4:
+ arr = arr[..., :3]
+ # calc diffs
+ base = arr[0, 0]
+ diffs = np.abs(arr.astype(np.int32) - base.astype(np.int32)).sum(axis=-1)
+ alpha = (diffs <= 80)
+
+ arr[alpha] = 255
+ alpha = ~alpha
+ arr = np.concatenate([arr, alpha[..., None].astype(np.int32) * 255], axis=-1)
+ return arr
+
+def simple_remove(imgs, run_sr=True):
+ """Only works for normal"""
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ single_input = True
+ else:
+ single_input = False
+ if run_sr:
+ imgs = run_sr_fast(imgs)
+ rets = []
+ for img in imgs:
+ arr = np.array(img)
+ arr = remove_color(arr)
+ rets.append(Image.fromarray(arr.astype(np.uint8)))
+ if single_input:
+ return rets[0]
+ return rets
+
+def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"):
+ new_image = Image.new("RGBA", rgba.size, bkgd)
+ new_image.paste(rgba, (0, 0), rgba)
+ new_image = new_image.convert('RGB')
+ return new_image
+
+def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"):
+ rgb_white = rgba_to_rgb(rgba, bkgd)
+ new_rgba = Image.fromarray(np.concatenate([np.array(rgb_white), np.array(rgba)[:, :, 3:4]], axis=-1))
+ return new_rgba
+
+def split_image(image, rows=None, cols=None):
+ """
+ inverse function of make_image_grid
+ """
+ # image is in square
+ if rows is None and cols is None:
+ # image.size [W, H]
+ rows = 1
+ cols = image.size[0] // image.size[1]
+ assert cols * image.size[1] == image.size[0]
+ subimg_size = image.size[1]
+ elif rows is None:
+ subimg_size = image.size[0] // cols
+ rows = image.size[1] // subimg_size
+ assert rows * subimg_size == image.size[1]
+ elif cols is None:
+ subimg_size = image.size[1] // rows
+ cols = image.size[0] // subimg_size
+ assert cols * subimg_size == image.size[0]
+ else:
+ subimg_size = image.size[1] // rows
+ assert cols * subimg_size == image.size[0]
+ subimgs = []
+ for i in range(rows):
+ for j in range(cols):
+ subimg = image.crop((j*subimg_size, i*subimg_size, (j+1)*subimg_size, (i+1)*subimg_size))
+ subimgs.append(subimg)
+ return subimgs
+
+def make_image_grid(images, rows=None, cols=None, resize=None):
+ if rows is None and cols is None:
+ rows = 1
+ cols = len(images)
+ if rows is None:
+ rows = len(images) // cols
+ if len(images) % cols != 0:
+ rows += 1
+ if cols is None:
+ cols = len(images) // rows
+ if len(images) % rows != 0:
+ cols += 1
+ total_imgs = rows * cols
+ if total_imgs > len(images):
+ images += [Image.new(images[0].mode, images[0].size) for _ in range(total_imgs - len(images))]
+
+ if resize is not None:
+ images = [img.resize((resize, resize)) for img in images]
+
+ w, h = images[0].size
+ grid = Image.new(images[0].mode, size=(cols * w, rows * h))
+
+ for i, img in enumerate(images):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+
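+
+# Minimal usage sketch (illustrative, not part of the original module):
+# make_image_grid and split_image act as inverses for equally sized square tiles.
+if __name__ == "__main__":
+    tiles = [Image.new("RGB", (256, 256), color) for color in ("red", "green", "blue", "white")]
+    grid = make_image_grid(tiles, rows=2)   # 2x2 grid -> 512x512 image
+    recovered = split_image(grid, rows=2)   # back to four 256x256 tiles
+    assert len(recovered) == 4 and recovered[0].size == (256, 256)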
diff --git a/assets/teaser.jpg b/assets/teaser.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..53a97e29b8fbd5719199a07bbca041f1faebb89e
Binary files /dev/null and b/assets/teaser.jpg differ
diff --git a/ckpt/controlnet-tile/config.json b/ckpt/controlnet-tile/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..754ad3afbb1f2cc8de0fad9ad5e337ecf0eb2681
--- /dev/null
+++ b/ckpt/controlnet-tile/config.json
@@ -0,0 +1,52 @@
+{
+ "_class_name": "ControlNetModel",
+ "_diffusers_version": "0.27.2",
+ "_name_or_path": "lllyasviel/control_v11f1e_sd15_tile",
+ "act_fn": "silu",
+ "addition_embed_type": null,
+ "addition_embed_type_num_heads": 64,
+ "addition_time_embed_dim": null,
+ "attention_head_dim": 8,
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280,
+ 1280
+ ],
+ "class_embed_type": null,
+ "conditioning_channels": 3,
+ "conditioning_embedding_out_channels": [
+ 16,
+ 32,
+ 96,
+ 256
+ ],
+ "controlnet_conditioning_channel_order": "rgb",
+ "cross_attention_dim": 768,
+ "down_block_types": [
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D"
+ ],
+ "downsample_padding": 1,
+ "encoder_hid_dim": null,
+ "encoder_hid_dim_type": null,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "global_pool_conditions": false,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "mid_block_scale_factor": 1,
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "num_attention_heads": null,
+ "num_class_embeds": null,
+ "only_cross_attention": false,
+ "projection_class_embeddings_input_dim": null,
+ "resnet_time_scale_shift": "default",
+ "transformer_layers_per_block": 1,
+ "upcast_attention": false,
+ "use_linear_projection": false
+}
diff --git a/ckpt/controlnet-tile/diffusion_pytorch_model.safetensors b/ckpt/controlnet-tile/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..7f56d24eefd260afd7a0877e4a37b76fa7c0d69a
--- /dev/null
+++ b/ckpt/controlnet-tile/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:845d3845053912728cd1453029a0ef87d3c0a3082a083ba393f36eaa5fb0e218
+size 1445157120
diff --git a/ckpt/image2normal/feature_extractor/preprocessor_config.json b/ckpt/image2normal/feature_extractor/preprocessor_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..c7f2d86b29c96c6db8ae37b12cc51af1b9c7d195
--- /dev/null
+++ b/ckpt/image2normal/feature_extractor/preprocessor_config.json
@@ -0,0 +1,44 @@
+{
+ "_valid_processor_keys": [
+ "images",
+ "do_resize",
+ "size",
+ "resample",
+ "do_center_crop",
+ "crop_size",
+ "do_rescale",
+ "rescale_factor",
+ "do_normalize",
+ "image_mean",
+ "image_std",
+ "do_convert_rgb",
+ "return_tensors",
+ "data_format",
+ "input_data_format"
+ ],
+ "crop_size": {
+ "height": 224,
+ "width": 224
+ },
+ "do_center_crop": true,
+ "do_convert_rgb": true,
+ "do_normalize": true,
+ "do_rescale": true,
+ "do_resize": true,
+ "image_mean": [
+ 0.48145466,
+ 0.4578275,
+ 0.40821073
+ ],
+ "image_processor_type": "CLIPImageProcessor",
+ "image_std": [
+ 0.26862954,
+ 0.26130258,
+ 0.27577711
+ ],
+ "resample": 3,
+ "rescale_factor": 0.00392156862745098,
+ "size": {
+ "shortest_edge": 224
+ }
+}
diff --git a/ckpt/image2normal/image_encoder/config.json b/ckpt/image2normal/image_encoder/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..2a2fb07882916f33e8e2e762c1080541929acdcb
--- /dev/null
+++ b/ckpt/image2normal/image_encoder/config.json
@@ -0,0 +1,23 @@
+{
+ "_name_or_path": "lambdalabs/sd-image-variations-diffusers",
+ "architectures": [
+ "CLIPVisionModelWithProjection"
+ ],
+ "attention_dropout": 0.0,
+ "dropout": 0.0,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 1024,
+ "image_size": 224,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "layer_norm_eps": 1e-05,
+ "model_type": "clip_vision_model",
+ "num_attention_heads": 16,
+ "num_channels": 3,
+ "num_hidden_layers": 24,
+ "patch_size": 14,
+ "projection_dim": 768,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.39.3"
+}
diff --git a/ckpt/image2normal/image_encoder/model.safetensors b/ckpt/image2normal/image_encoder/model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..0ccebd91eb644560fa5b3623167df0cf2e36f890
--- /dev/null
+++ b/ckpt/image2normal/image_encoder/model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e4b33d864f89a793357a768cb07d0dc18d6a14e6664f4110a0d535ca9ba78da8
+size 607980488
diff --git a/ckpt/image2normal/model_index.json b/ckpt/image2normal/model_index.json
new file mode 100644
index 0000000000000000000000000000000000000000..5a6a1609d2289f834c83f05054317963598bb474
--- /dev/null
+++ b/ckpt/image2normal/model_index.json
@@ -0,0 +1,31 @@
+{
+ "_class_name": "StableDiffusionImageCustomPipeline",
+ "_diffusers_version": "0.27.2",
+ "_name_or_path": "lambdalabs/sd-image-variations-diffusers",
+ "feature_extractor": [
+ "transformers",
+ "CLIPImageProcessor"
+ ],
+ "image_encoder": [
+ "transformers",
+ "CLIPVisionModelWithProjection"
+ ],
+ "noisy_cond_latents": false,
+ "requires_safety_checker": true,
+ "safety_checker": [
+ null,
+ null
+ ],
+ "scheduler": [
+ "diffusers",
+ "EulerAncestralDiscreteScheduler"
+ ],
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
diff --git a/ckpt/image2normal/scheduler/scheduler_config.json b/ckpt/image2normal/scheduler/scheduler_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..35d882be7ab97389295e71db656abbe07bf61ae1
--- /dev/null
+++ b/ckpt/image2normal/scheduler/scheduler_config.json
@@ -0,0 +1,16 @@
+{
+ "_class_name": "EulerAncestralDiscreteScheduler",
+ "_diffusers_version": "0.27.2",
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": false,
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "rescale_betas_zero_snr": false,
+ "set_alpha_to_one": false,
+ "skip_prk_steps": true,
+ "steps_offset": 1,
+ "timestep_spacing": "linspace",
+ "trained_betas": null
+}
diff --git a/ckpt/image2normal/unet/config.json b/ckpt/image2normal/unet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..0212cd93d1eada86bed425d88124ba1760d7d10c
--- /dev/null
+++ b/ckpt/image2normal/unet/config.json
@@ -0,0 +1,68 @@
+{
+ "_class_name": "UnifieldWrappedUNet",
+ "_diffusers_version": "0.27.2",
+ "_name_or_path": "lambdalabs/sd-image-variations-diffusers",
+ "act_fn": "silu",
+ "addition_embed_type": null,
+ "addition_embed_type_num_heads": 64,
+ "addition_time_embed_dim": null,
+ "attention_head_dim": 8,
+ "attention_type": "default",
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280,
+ 1280
+ ],
+ "center_input_sample": false,
+ "class_embed_type": null,
+ "class_embeddings_concat": false,
+ "conv_in_kernel": 3,
+ "conv_out_kernel": 3,
+ "cross_attention_dim": 768,
+ "cross_attention_norm": null,
+ "down_block_types": [
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D"
+ ],
+ "downsample_padding": 1,
+ "dropout": 0.0,
+ "dual_cross_attention": false,
+ "encoder_hid_dim": null,
+ "encoder_hid_dim_type": null,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "mid_block_only_cross_attention": null,
+ "mid_block_scale_factor": 1,
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "num_attention_heads": null,
+ "num_class_embeds": null,
+ "only_cross_attention": false,
+ "out_channels": 4,
+ "projection_class_embeddings_input_dim": null,
+ "resnet_out_scale_factor": 1.0,
+ "resnet_skip_time_act": false,
+ "resnet_time_scale_shift": "default",
+ "reverse_transformer_layers_per_block": null,
+ "sample_size": 64,
+ "time_cond_proj_dim": null,
+ "time_embedding_act_fn": null,
+ "time_embedding_dim": null,
+ "time_embedding_type": "positional",
+ "timestep_post_act": null,
+ "transformer_layers_per_block": 1,
+ "up_block_types": [
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D"
+ ],
+ "upcast_attention": false,
+ "use_linear_projection": false
+}
diff --git a/ckpt/image2normal/unet/diffusion_pytorch_model.safetensors b/ckpt/image2normal/unet/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..fcd2824e099b9951b853cafe864693c968b1e36f
--- /dev/null
+++ b/ckpt/image2normal/unet/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f5cbaf1d56619345ce78de8cfbb20d94923b3305a364bf6a5b2a2cc422d4b701
+size 3537503456
diff --git a/ckpt/image2normal/unet_state_dict.pth b/ckpt/image2normal/unet_state_dict.pth
new file mode 100644
index 0000000000000000000000000000000000000000..551f25bdb2bdbf46ee643c73597600f811870331
--- /dev/null
+++ b/ckpt/image2normal/unet_state_dict.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8df80d09e953d338aa6d8decd0351c5045f52ec6e2645eee2027ccb8792c8ed8
+size 3537964654
diff --git a/ckpt/image2normal/vae/config.json b/ckpt/image2normal/vae/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..5dac9329af91a055d118281dfae1e6ba77b91762
--- /dev/null
+++ b/ckpt/image2normal/vae/config.json
@@ -0,0 +1,34 @@
+{
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.27.2",
+ "_name_or_path": "lambdalabs/sd-image-variations-diffusers",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": true,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "latents_mean": null,
+ "latents_std": null,
+ "layers_per_block": 2,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ]
+}
diff --git a/ckpt/image2normal/vae/diffusion_pytorch_model.safetensors b/ckpt/image2normal/vae/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..ae4bb53593b4381098d3afb99c5ca2b253c0d86b
--- /dev/null
+++ b/ckpt/image2normal/vae/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d0c34f57abe50f323040f2366c8e22b941068dcdf53c8eb1d6fafb838afecb7
+size 167335590
diff --git a/ckpt/img2mvimg/feature_extractor/preprocessor_config.json b/ckpt/img2mvimg/feature_extractor/preprocessor_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..c7f2d86b29c96c6db8ae37b12cc51af1b9c7d195
--- /dev/null
+++ b/ckpt/img2mvimg/feature_extractor/preprocessor_config.json
@@ -0,0 +1,44 @@
+{
+ "_valid_processor_keys": [
+ "images",
+ "do_resize",
+ "size",
+ "resample",
+ "do_center_crop",
+ "crop_size",
+ "do_rescale",
+ "rescale_factor",
+ "do_normalize",
+ "image_mean",
+ "image_std",
+ "do_convert_rgb",
+ "return_tensors",
+ "data_format",
+ "input_data_format"
+ ],
+ "crop_size": {
+ "height": 224,
+ "width": 224
+ },
+ "do_center_crop": true,
+ "do_convert_rgb": true,
+ "do_normalize": true,
+ "do_rescale": true,
+ "do_resize": true,
+ "image_mean": [
+ 0.48145466,
+ 0.4578275,
+ 0.40821073
+ ],
+ "image_processor_type": "CLIPImageProcessor",
+ "image_std": [
+ 0.26862954,
+ 0.26130258,
+ 0.27577711
+ ],
+ "resample": 3,
+ "rescale_factor": 0.00392156862745098,
+ "size": {
+ "shortest_edge": 224
+ }
+}
diff --git a/ckpt/img2mvimg/image_encoder/config.json b/ckpt/img2mvimg/image_encoder/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..c6cfa1872a8bfde94e4bce13dc66c94d01966422
--- /dev/null
+++ b/ckpt/img2mvimg/image_encoder/config.json
@@ -0,0 +1,23 @@
+{
+ "_name_or_path": "lambdalabs/sd-image-variations-diffusers",
+ "architectures": [
+ "CLIPVisionModelWithProjection"
+ ],
+ "attention_dropout": 0.0,
+ "dropout": 0.0,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 1024,
+ "image_size": 224,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "layer_norm_eps": 1e-05,
+ "model_type": "clip_vision_model",
+ "num_attention_heads": 16,
+ "num_channels": 3,
+ "num_hidden_layers": 24,
+ "patch_size": 14,
+ "projection_dim": 768,
+ "torch_dtype": "float32",
+ "transformers_version": "4.39.3"
+}
diff --git a/ckpt/img2mvimg/image_encoder/model.safetensors b/ckpt/img2mvimg/image_encoder/model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..18f4f46f7c10b48fda406d8f6c2092836290ceba
--- /dev/null
+++ b/ckpt/img2mvimg/image_encoder/model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:77b33d2a3a643650857672e880ccf73adbaf114fbbadec36d142ee9d48af7e20
+size 1215912728
diff --git a/ckpt/img2mvimg/model_index.json b/ckpt/img2mvimg/model_index.json
new file mode 100644
index 0000000000000000000000000000000000000000..afe0b226c23c02831871707fe366943e0e9310de
--- /dev/null
+++ b/ckpt/img2mvimg/model_index.json
@@ -0,0 +1,31 @@
+{
+ "_class_name": "StableDiffusionImage2MVCustomPipeline",
+ "_diffusers_version": "0.27.2",
+ "_name_or_path": "lambdalabs/sd-image-variations-diffusers",
+ "condition_offset": true,
+ "feature_extractor": [
+ "transformers",
+ "CLIPImageProcessor"
+ ],
+ "image_encoder": [
+ "transformers",
+ "CLIPVisionModelWithProjection"
+ ],
+ "requires_safety_checker": true,
+ "safety_checker": [
+ null,
+ null
+ ],
+ "scheduler": [
+ "diffusers",
+ "DDIMScheduler"
+ ],
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
diff --git a/ckpt/img2mvimg/scheduler/scheduler_config.json b/ckpt/img2mvimg/scheduler/scheduler_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..788d42392ab8369f4f2a4899bcbf557eda772ecc
--- /dev/null
+++ b/ckpt/img2mvimg/scheduler/scheduler_config.json
@@ -0,0 +1,20 @@
+{
+ "_class_name": "DDIMScheduler",
+ "_diffusers_version": "0.27.2",
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": false,
+ "clip_sample_range": 1.0,
+ "dynamic_thresholding_ratio": 0.995,
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "rescale_betas_zero_snr": false,
+ "sample_max_value": 1.0,
+ "set_alpha_to_one": false,
+ "skip_prk_steps": true,
+ "steps_offset": 1,
+ "thresholding": false,
+ "timestep_spacing": "leading",
+ "trained_betas": null
+}
diff --git a/ckpt/img2mvimg/unet/config.json b/ckpt/img2mvimg/unet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..851f2e10f0879ca02b54f2fb63f5438ce7ef1c06
--- /dev/null
+++ b/ckpt/img2mvimg/unet/config.json
@@ -0,0 +1,68 @@
+{
+ "_class_name": "UnifieldWrappedUNet",
+ "_diffusers_version": "0.27.2",
+ "_name_or_path": "lambdalabs/sd-image-variations-diffusers",
+ "act_fn": "silu",
+ "addition_embed_type": null,
+ "addition_embed_type_num_heads": 64,
+ "addition_time_embed_dim": null,
+ "attention_head_dim": 8,
+ "attention_type": "default",
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280,
+ 1280
+ ],
+ "center_input_sample": false,
+ "class_embed_type": null,
+ "class_embeddings_concat": false,
+ "conv_in_kernel": 3,
+ "conv_out_kernel": 3,
+ "cross_attention_dim": 768,
+ "cross_attention_norm": null,
+ "down_block_types": [
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D"
+ ],
+ "downsample_padding": 1,
+ "dropout": 0.0,
+ "dual_cross_attention": false,
+ "encoder_hid_dim": null,
+ "encoder_hid_dim_type": null,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "in_channels": 8,
+ "layers_per_block": 2,
+ "mid_block_only_cross_attention": null,
+ "mid_block_scale_factor": 1,
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "num_attention_heads": null,
+ "num_class_embeds": 8,
+ "only_cross_attention": false,
+ "out_channels": 4,
+ "projection_class_embeddings_input_dim": null,
+ "resnet_out_scale_factor": 1.0,
+ "resnet_skip_time_act": false,
+ "resnet_time_scale_shift": "default",
+ "reverse_transformer_layers_per_block": null,
+ "sample_size": 64,
+ "time_cond_proj_dim": null,
+ "time_embedding_act_fn": null,
+ "time_embedding_dim": null,
+ "time_embedding_type": "positional",
+ "timestep_post_act": null,
+ "transformer_layers_per_block": 1,
+ "up_block_types": [
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D"
+ ],
+ "upcast_attention": false,
+ "use_linear_projection": false
+}
diff --git a/ckpt/img2mvimg/unet/diffusion_pytorch_model.safetensors b/ckpt/img2mvimg/unet/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..698ab80930420e263414db5eee9dac62f8001fdf
--- /dev/null
+++ b/ckpt/img2mvimg/unet/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:93a3b4e678efac0c997e76df465df13136a4b0f1732e534a1200fad9e04cd0f9
+size 3438254688
diff --git a/ckpt/img2mvimg/unet_state_dict.pth b/ckpt/img2mvimg/unet_state_dict.pth
new file mode 100644
index 0000000000000000000000000000000000000000..979b10289057f68db00a50905f68cd780bb1e5c5
--- /dev/null
+++ b/ckpt/img2mvimg/unet_state_dict.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0dff2fdba450af0e10c3a847ba66a530170be2e9b9c9f4c834483515e82738b5
+size 3438460972
diff --git a/ckpt/img2mvimg/vae/config.json b/ckpt/img2mvimg/vae/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..5dac9329af91a055d118281dfae1e6ba77b91762
--- /dev/null
+++ b/ckpt/img2mvimg/vae/config.json
@@ -0,0 +1,34 @@
+{
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.27.2",
+ "_name_or_path": "lambdalabs/sd-image-variations-diffusers",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": true,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "latents_mean": null,
+ "latents_std": null,
+ "layers_per_block": 2,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ]
+}
diff --git a/ckpt/img2mvimg/vae/diffusion_pytorch_model.safetensors b/ckpt/img2mvimg/vae/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..d6fc2b1f7ae2b1f4f83c25812f819a17473f0c1a
--- /dev/null
+++ b/ckpt/img2mvimg/vae/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
+size 334643268
diff --git a/ckpt/realesrgan-x4.onnx b/ckpt/realesrgan-x4.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..92f8ebc05ec7159cf963de0022c6bc9a5cd9f9c8
--- /dev/null
+++ b/ckpt/realesrgan-x4.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9bc5d0c85207adad8bca26286f0c0007f266f85e7aa7c454c565da9b5f3c940a
+size 67051617
diff --git a/ckpt/v1-inference.yaml b/ckpt/v1-inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d4effe569e897369918625f9d8be5603a0e6a0d6
--- /dev/null
+++ b/ckpt/v1-inference.yaml
@@ -0,0 +1,70 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
diff --git a/custum_3d_diffusion/custum_modules/attention_processors.py b/custum_3d_diffusion/custum_modules/attention_processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b01cc45f3e6fbd7fa23f5409dd00fa7222b7762
--- /dev/null
+++ b/custum_3d_diffusion/custum_modules/attention_processors.py
@@ -0,0 +1,385 @@
+from typing import Any, Dict, Optional
+import torch
+from diffusers.models.attention_processor import Attention
+
+def construct_pix2pix_attention(hidden_states_dim, norm_type="none"):
+ if norm_type == "layernorm":
+ norm = torch.nn.LayerNorm(hidden_states_dim)
+ else:
+ norm = torch.nn.Identity()
+ attention = Attention(
+ query_dim=hidden_states_dim,
+ heads=8,
+ dim_head=hidden_states_dim // 8,
+ bias=True,
+ )
+ # NOTE: xformers 0.22 does not support batchsize >= 4096
+ attention.xformers_not_supported = True # hacky solution
+ return norm, attention
+
+class ExtraAttnProc(torch.nn.Module):
+ def __init__(
+ self,
+ chained_proc,
+ enabled=False,
+ name=None,
+ mode='extract',
+ with_proj_in=False,
+ proj_in_dim=768,
+ target_dim=None,
+ pixel_wise_crosspond=False,
+ norm_type="none", # none or layernorm
+ crosspond_effect_on="all", # all or first
+ crosspond_chain_pos="parralle", # before or parralle or after
+ simple_3d=False,
+ views=4,
+ ) -> None:
+ super().__init__()
+ self.enabled = enabled
+ self.chained_proc = chained_proc
+ self.name = name
+ self.mode = mode
+ self.with_proj_in=with_proj_in
+ self.proj_in_dim = proj_in_dim
+ self.target_dim = target_dim or proj_in_dim
+ self.hidden_states_dim = self.target_dim
+ self.pixel_wise_crosspond = pixel_wise_crosspond
+ self.crosspond_effect_on = crosspond_effect_on
+ self.crosspond_chain_pos = crosspond_chain_pos
+ self.views = views
+ self.simple_3d = simple_3d
+ if self.with_proj_in and self.enabled:
+ self.in_linear = torch.nn.Linear(self.proj_in_dim, self.target_dim, bias=False)
+ if self.target_dim == self.proj_in_dim:
+ self.in_linear.weight.data = torch.eye(proj_in_dim)
+ else:
+ self.in_linear = None
+ if self.pixel_wise_crosspond and self.enabled:
+ self.crosspond_norm, self.crosspond_attention = construct_pix2pix_attention(self.hidden_states_dim, norm_type=norm_type)
+
+ def do_crosspond_attention(self, hidden_states: torch.FloatTensor, other_states: torch.FloatTensor):
+ hidden_states = self.crosspond_norm(hidden_states)
+
+ batch, L, D = hidden_states.shape
+ assert hidden_states.shape == other_states.shape, f"got {hidden_states.shape} and {other_states.shape}"
+ # to -> batch * L, 1, D
+ hidden_states = hidden_states.reshape(batch * L, 1, D)
+ other_states = other_states.reshape(batch * L, 1, D)
+ hidden_states_catted = other_states
+ hidden_states = self.crosspond_attention(
+ hidden_states,
+ encoder_hidden_states=hidden_states_catted,
+ )
+ return hidden_states.reshape(batch, L, D)
+
+ def __call__(
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+ ref_dict: dict = None, mode=None, **kwargs
+ ) -> Any:
+ if not self.enabled:
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ assert ref_dict is not None
+ if (mode or self.mode) == 'extract':
+ ref_dict[self.name] = hidden_states
+ hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+ if self.pixel_wise_crosspond and self.crosspond_chain_pos == "after":
+ ref_dict[self.name] = hidden_states1
+ return hidden_states1
+ elif (mode or self.mode) == 'inject':
+ ref_state = ref_dict.pop(self.name)
+ if self.with_proj_in:
+ ref_state = self.in_linear(ref_state)
+
+ B, L, D = ref_state.shape
+ if hidden_states.shape[0] == B:
+ modalities = 1
+ views = 1
+ else:
+ modalities = hidden_states.shape[0] // B // self.views
+ views = self.views
+ if self.pixel_wise_crosspond:
+ if self.crosspond_effect_on == "all":
+ ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, *ref_state.shape[-2:])
+
+ if self.crosspond_chain_pos == "before":
+ hidden_states = hidden_states + self.do_crosspond_attention(hidden_states, ref_state)
+
+ hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+
+ if self.crosspond_chain_pos == "parralle":
+ hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states, ref_state)
+
+ if self.crosspond_chain_pos == "after":
+ hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states1, ref_state)
+ return hidden_states1
+ else:
+ assert self.crosspond_effect_on == "first"
+ # hidden_states [B * modalities * views, L, D]
+ # ref_state [B, L, D]
+ ref_state = ref_state[:, None].expand(-1, modalities, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1]) # [B * modalities, L, D]
+
+ def do_paritial_crosspond(hidden_states, ref_state):
+ first_view_hidden_states = hidden_states.view(-1, views, hidden_states.shape[1], hidden_states.shape[2])[:, 0] # [B * modalities, L, D]
+ hidden_states2 = self.do_crosspond_attention(first_view_hidden_states, ref_state) # [B * modalities, L, D]
+ hidden_states2_padded = torch.zeros_like(hidden_states).reshape(-1, views, hidden_states.shape[1], hidden_states.shape[2])
+ hidden_states2_padded[:, 0] = hidden_states2
+ hidden_states2_padded = hidden_states2_padded.reshape(-1, hidden_states.shape[1], hidden_states.shape[2])
+ return hidden_states2_padded
+
+ if self.crosspond_chain_pos == "before":
+ hidden_states = hidden_states + do_paritial_crosspond(hidden_states, ref_state)
+
+ hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) # [B * modalities * views, L, D]
+ if self.crosspond_chain_pos == "parralle":
+ hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states, ref_state)
+ if self.crosspond_chain_pos == "after":
+ hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states1, ref_state)
+ return hidden_states1
+ elif self.simple_3d:
+ B, L, C = encoder_hidden_states.shape
+ mv = self.views
+ encoder_hidden_states = encoder_hidden_states.reshape(B // mv, mv, L, C)
+ ref_state = ref_state[:, None]
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
+ encoder_hidden_states = encoder_hidden_states.reshape(B // mv, 1, (mv+1) * L, C)
+ encoder_hidden_states = encoder_hidden_states.repeat(1, mv, 1, 1).reshape(-1, (mv+1) * L, C)
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+ else:
+ ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1])
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+ else:
+ raise NotImplementedError("mode or self.mode is required to be 'extract' or 'inject'")
+
+def add_extra_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
+ return_dict = torch.nn.ModuleDict()
+    proj_in_dim = kwargs.pop('proj_in_dim', None)
+
+ def recursive_add_processors(name: str, module: torch.nn.Module):
+ for sub_name, child in module.named_children():
+ if "ref_unet" not in (sub_name + name):
+ recursive_add_processors(f"{name}.{sub_name}", child)
+
+ if isinstance(module, Attention):
+ new_processor = ExtraAttnProc(
+ chained_proc=module.get_processor(),
+ enabled=enable_filter(f"{name}.processor"),
+ name=f"{name}.processor",
+ proj_in_dim=proj_in_dim if proj_in_dim else module.cross_attention_dim,
+ target_dim=module.cross_attention_dim,
+ **kwargs
+ )
+ module.set_processor(new_processor)
+ return_dict[f"{name}.processor".replace(".", "__")] = new_processor
+
+ for name, module in model.named_children():
+ recursive_add_processors(name, module)
+ return return_dict
+
+def switch_extra_processor(model, enable_filter=lambda x:True):
+ def recursive_add_processors(name: str, module: torch.nn.Module):
+ for sub_name, child in module.named_children():
+ recursive_add_processors(f"{name}.{sub_name}", child)
+
+ if isinstance(module, ExtraAttnProc):
+ module.enabled = enable_filter(name)
+
+ for name, module in model.named_children():
+ recursive_add_processors(name, module)
+
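+# multiviewAttnProc folds the per-view batch dimension into the sequence dimension,
+# i.e. (B * views, L, C) -> (B, views * L, C), so that self-attention is computed
+# jointly across all generated views and enforces multi-view consistency.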
+class multiviewAttnProc(torch.nn.Module):
+ def __init__(
+ self,
+ chained_proc,
+ enabled=False,
+ name=None,
+ hidden_states_dim=None,
+ chain_pos="parralle", # before or parralle or after
+ num_modalities=1,
+ views=4,
+ base_img_size=64,
+ ) -> None:
+ super().__init__()
+ self.enabled = enabled
+ self.chained_proc = chained_proc
+ self.name = name
+ self.hidden_states_dim = hidden_states_dim
+ self.num_modalities = num_modalities
+ self.views = views
+ self.base_img_size = base_img_size
+ self.chain_pos = chain_pos
+ self.diff_joint_attn = True
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ **kwargs
+ ) -> torch.Tensor:
+ if not self.enabled:
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+
+ B, L, C = hidden_states.shape
+ mv = self.views
+ hidden_states = hidden_states.reshape(B // mv, mv, L, C).reshape(-1, mv * L, C)
+ hidden_states = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+ return hidden_states.reshape(B // mv, mv, L, C).reshape(-1, L, C)
+
+def add_multiview_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
+ return_dict = torch.nn.ModuleDict()
+ def recursive_add_processors(name: str, module: torch.nn.Module):
+ for sub_name, child in module.named_children():
+ if "ref_unet" not in (sub_name + name):
+ recursive_add_processors(f"{name}.{sub_name}", child)
+
+ if isinstance(module, Attention):
+ new_processor = multiviewAttnProc(
+ chained_proc=module.get_processor(),
+ enabled=enable_filter(f"{name}.processor"),
+ name=f"{name}.processor",
+ hidden_states_dim=module.inner_dim,
+ **kwargs
+ )
+ module.set_processor(new_processor)
+ return_dict[f"{name}.processor".replace(".", "__")] = new_processor
+
+ for name, module in model.named_children():
+ recursive_add_processors(name, module)
+
+ return return_dict
+
+def switch_multiview_processor(model, enable_filter=lambda x:True):
+ def recursive_add_processors(name: str, module: torch.nn.Module):
+ for sub_name, child in module.named_children():
+ recursive_add_processors(f"{name}.{sub_name}", child)
+
+ if isinstance(module, Attention):
+ processor = module.get_processor()
+ if isinstance(processor, multiviewAttnProc):
+ processor.enabled = enable_filter(f"{name}.processor")
+
+ for name, module in model.named_children():
+ recursive_add_processors(name, module)
+
+class NNModuleWrapper(torch.nn.Module):
+ def __init__(self, module):
+ super().__init__()
+ self.module = module
+
+ def forward(self, *args, **kwargs):
+ return self.module(*args, **kwargs)
+
+ def __getattr__(self, name: str):
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ return getattr(self.module, name)
+
+class AttnProcessorSwitch(torch.nn.Module):
+ def __init__(
+ self,
+ proc_dict: dict,
+ enabled_proc="default",
+ name=None,
+ switch_name="default_switch",
+ ):
+ super().__init__()
+ self.proc_dict = torch.nn.ModuleDict({k: (v if isinstance(v, torch.nn.Module) else NNModuleWrapper(v)) for k, v in proc_dict.items()})
+ self.enabled_proc = enabled_proc
+ self.name = name
+ self.switch_name = switch_name
+ self.choose_module(enabled_proc)
+
+ def choose_module(self, enabled_proc):
+ self.enabled_proc = enabled_proc
+ assert enabled_proc in self.proc_dict.keys()
+
+ def __call__(
+ self,
+ *args,
+ **kwargs
+ ) -> torch.FloatTensor:
+ used_proc = self.proc_dict[self.enabled_proc]
+ return used_proc(*args, **kwargs)
+
+def add_switch(model: torch.nn.Module, module_filter=lambda x:True, switch_dict_fn=lambda x: {"default": x}, switch_name="default_switch", enabled_proc="default"):
+ return_dict = torch.nn.ModuleDict()
+ def recursive_add_processors(name: str, module: torch.nn.Module):
+ for sub_name, child in module.named_children():
+ if "ref_unet" not in (sub_name + name):
+ recursive_add_processors(f"{name}.{sub_name}", child)
+
+ if isinstance(module, Attention):
+ processor = module.get_processor()
+ if module_filter(processor):
+ proc_dict = switch_dict_fn(processor)
+ new_processor = AttnProcessorSwitch(
+ proc_dict=proc_dict,
+ enabled_proc=enabled_proc,
+ name=f"{name}.processor",
+ switch_name=switch_name,
+ )
+ module.set_processor(new_processor)
+ return_dict[f"{name}.processor".replace(".", "__")] = new_processor
+
+ for name, module in model.named_children():
+ recursive_add_processors(name, module)
+
+ return return_dict
+
+def change_switch(model: torch.nn.Module, switch_name="default_switch", enabled_proc="default"):
+ def recursive_change_processors(name: str, module: torch.nn.Module):
+ for sub_name, child in module.named_children():
+ recursive_change_processors(f"{name}.{sub_name}", child)
+
+ if isinstance(module, Attention):
+ processor = module.get_processor()
+ if isinstance(processor, AttnProcessorSwitch) and processor.switch_name == switch_name:
+ processor.choose_module(enabled_proc)
+
+ for name, module in model.named_children():
+ recursive_change_processors(name, module)
+
+########## Hack: Attention fix #############
+from diffusers.models.attention import Attention
+
+def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ **cross_attention_kwargs,
+) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+Attention.forward = forward
\ No newline at end of file
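
The processors defined above are attached by recursively walking a UNet's module tree. Below is a hedged sketch of wiring `add_extra_processor` / `switch_extra_processor` onto a stock diffusers UNet; the checkpoint id and the `attn1` filter are illustrative choices, not settings taken from this repository's training runs:

```python
# Illustrative sketch only: attach ExtraAttnProc instances to every
# self-attention ("attn1") processor of a plain UNet2DConditionModel.
import torch
from diffusers import UNet2DConditionModel
from custum_3d_diffusion.custum_modules.attention_processors import (
    add_extra_processor,
    switch_extra_processor,
)

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float32
)

# Returns a ModuleDict of the newly installed ExtraAttnProc modules so their
# parameters can be registered with an optimizer.
procs = add_extra_processor(
    model=unet,
    enable_filter=lambda name: name.endswith("attn1.processor"),
    mode="inject",
    pixel_wise_crosspond=True,
)
print(f"{len(procs)} attention processors replaced")

# The same processors can later be toggled off without rebuilding the UNet.
switch_extra_processor(unet, enable_filter=lambda name: False)
```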
diff --git a/custum_3d_diffusion/custum_modules/unifield_processor.py b/custum_3d_diffusion/custum_modules/unifield_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f880932d9b2ec5ff789e2cb6cfd6c5c33942993
--- /dev/null
+++ b/custum_3d_diffusion/custum_modules/unifield_processor.py
@@ -0,0 +1,459 @@
+from types import FunctionType
+from typing import Any, Dict, List
+from diffusers import UNet2DConditionModel
+import torch
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, ImageProjection
+from diffusers.models.attention_processor import Attention, AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
+from dataclasses import dataclass, field
+from diffusers.loaders import IPAdapterMixin
+from custum_3d_diffusion.custum_modules.attention_processors import add_extra_processor, switch_extra_processor, add_multiview_processor, switch_multiview_processor, add_switch, change_switch
+
+@dataclass
+class AttnConfig:
+ """
+    Attention branches controlled by this config:
+    * Cross-attention: base Attention module (inherits pretrained knowledge), LoRA module (fine-tuning), IP-Adapter module (concept-level image control).
+    * Self-attention: base Attention module (inherits pretrained knowledge), LoRA module (fine-tuning), Reference Attention module (pixel-level control).
+    * Multiview Attention module: enforces multi-view consistency.
+    * Cross-Modality Attention module: enforces multi-modality consistency.
+
+    Setup flags:
+    train_xxx_lr: the trained parameters live in the U-Net itself.
+    enable_xxx_lora: implemented in the U-Net itself.
+    enable_xxx_ip: implemented in both the attention processors and the U-Net.
+    enable_xxx_ref_proj_in: implemented in the attention processors.
+ """
+ latent_size: int = 64
+
+ train_lr: float = 0
+ # for cross attention
+ # 0 learning rate for not training
+ train_cross_attn_lr: float = 0
+ train_cross_attn_lora_lr: float = 0
+ train_cross_attn_ip_lr: float = 0 # 0 for not trained
+ init_cross_attn_lora: bool = False
+ enable_cross_attn_lora: bool = False
+ init_cross_attn_ip: bool = False
+ enable_cross_attn_ip: bool = False
+ cross_attn_lora_rank: int = 64 # 0 for not enabled
+ cross_attn_lora_only_kv: bool = False
+ ipadapter_pretrained_name: str = "h94/IP-Adapter"
+ ipadapter_subfolder_name: str = "models"
+ ipadapter_weight_name: str = "ip-adapter-plus_sd15.safetensors"
+ ipadapter_effect_on: str = "all" # all, first
+
+ # for self attention
+ train_self_attn_lr: float = 0
+ train_self_attn_lora_lr: float = 0
+ init_self_attn_lora: bool = False
+ enable_self_attn_lora: bool = False
+ self_attn_lora_rank: int = 64
+ self_attn_lora_only_kv: bool = False
+
+ train_self_attn_ref_lr: float = 0
+ train_ref_unet_lr: float = 0
+ init_self_attn_ref: bool = False
+ enable_self_attn_ref: bool = False
+ self_attn_ref_other_model_name: str = ""
+ self_attn_ref_position: str = "attn1"
+ self_attn_ref_pixel_wise_crosspond: bool = False # enable pixel_wise_crosspond in refattn
+ self_attn_ref_chain_pos: str = "parralle" # before or parralle or after
+ self_attn_ref_effect_on: str = "all" # all or first, for _crosspond attn
+ self_attn_ref_zero_init: bool = True
+ use_simple3d_attn: bool = False
+
+ # for multiview attention
+ init_multiview_attn: bool = False
+ enable_multiview_attn: bool = False
+ multiview_attn_position: str = "attn1"
+ multiview_chain_pose: str = "parralle" # before or parralle or after
+ num_modalities: int = 1
+ use_mv_joint_attn: bool = False
+
+ # for unet
+ init_unet_path: str = "runwayml/stable-diffusion-v1-5"
+    init_num_cls_label: int = 0 # number of class labels used at initialization
+ cls_labels: List[int] = field(default_factory=lambda: [])
+ cls_label_type: str = "embedding"
+ cat_condition: bool = False # cat condition to input
+
+class Configurable:
+ attn_config: AttnConfig
+
+ def set_config(self, attn_config: AttnConfig):
+ raise NotImplementedError()
+
+ def update_config(self, attn_config: AttnConfig):
+ self.attn_config = attn_config
+
+ def do_set_config(self, attn_config: AttnConfig):
+ self.set_config(attn_config)
+ for name, module in self.named_modules():
+ if isinstance(module, Configurable):
+ if hasattr(module, "do_set_config"):
+ module.do_set_config(attn_config)
+ else:
+ print(f"Warning: {name} has no attribute do_set_config, but is an instance of Configurable")
+ module.attn_config = attn_config
+
+ def do_update_config(self, attn_config: AttnConfig):
+ self.update_config(attn_config)
+ for name, module in self.named_modules():
+ if isinstance(module, Configurable):
+ if hasattr(module, "do_update_config"):
+ module.do_update_config(attn_config)
+ else:
+ print(f"Warning: {name} has no attribute do_update_config, but is an instance of Configurable")
+ module.attn_config = attn_config
+
+from diffusers import ModelMixin # Must import ModelMixin for CompiledUNet
+class UnifieldWrappedUNet(UNet2DConditionModel):
+ forward_hook: FunctionType
+
+ def forward(self, *args, **kwargs):
+ if hasattr(self, 'forward_hook'):
+ return self.forward_hook(super().forward, *args, **kwargs)
+ return super().forward(*args, **kwargs)
+
+
+class ConfigurableUNet2DConditionModel(Configurable, IPAdapterMixin):
+ unet: UNet2DConditionModel
+
+ cls_embedding_param_dict = {}
+ cross_attn_lora_param_dict = {}
+ self_attn_lora_param_dict = {}
+ cross_attn_param_dict = {}
+ self_attn_param_dict = {}
+ ipadapter_param_dict = {}
+ ref_attn_param_dict = {}
+ ref_unet_param_dict = {}
+ multiview_attn_param_dict = {}
+ other_param_dict = {}
+
+ rev_param_name_mapping = {}
+
+ class_labels = []
+ def set_class_labels(self, class_labels: torch.Tensor):
+ if self.attn_config.init_num_cls_label != 0:
+ self.class_labels = class_labels.to(self.unet.device).long()
+
+ def __init__(self, init_config: AttnConfig, weight_dtype) -> None:
+ super().__init__()
+ self.weight_dtype = weight_dtype
+ self.set_config(init_config)
+
+ def enable_xformers_memory_efficient_attention(self):
+        # xformers is enabled per Attention module below (instead of calling
+        # self.unet.enable_xformers_memory_efficient_attention()) so that modules
+        # flagged with `xformers_not_supported` keep their current processors.
+ def recursive_add_processors(name: str, module: torch.nn.Module):
+ for sub_name, child in module.named_children():
+ recursive_add_processors(f"{name}.{sub_name}", child)
+
+ if isinstance(module, Attention):
+ if hasattr(module, 'xformers_not_supported'):
+ return
+ old_processor = module.get_processor()
+ if isinstance(old_processor, (AttnProcessor, AttnProcessor2_0)):
+ module.set_use_memory_efficient_attention_xformers(True)
+
+ for name, module in self.unet.named_children():
+ recursive_add_processors(name, module)
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ return getattr(self.unet, name)
+
+ # --- for IPAdapterMixin
+
+ def register_modules(self, **kwargs):
+ for name, module in kwargs.items():
+ # set models
+ setattr(self, name, module)
+
+ def register_to_config(self, **kwargs):
+ pass
+
+ def unload_ip_adapter(self):
+ raise NotImplementedError()
+
+ # --- for Configurable
+
+ def get_refunet(self):
+ if self.attn_config.self_attn_ref_other_model_name == "self":
+ return self.unet
+ else:
+ return self.unet.ref_unet
+
+ def set_config(self, attn_config: AttnConfig):
+ self.attn_config = attn_config
+
+ unet_type = UnifieldWrappedUNet
+ # class_embed_type = "projection" for 'camera'
+ # class_embed_type = None for 'embedding'
+ unet_kwargs = {}
+ if attn_config.init_num_cls_label > 0:
+ if attn_config.cls_label_type == "embedding":
+ unet_kwargs = {
+ "num_class_embeds": attn_config.init_num_cls_label,
+ "device_map": None,
+ "low_cpu_mem_usage": False,
+ "class_embed_type": None,
+ }
+ else:
+ raise ValueError(f"cls_label_type {attn_config.cls_label_type} is not supported")
+
+ self.unet: UnifieldWrappedUNet = unet_type.from_pretrained(
+ attn_config.init_unet_path, subfolder="unet", torch_dtype=self.weight_dtype,
+ **unet_kwargs
+ )
+ assert isinstance(self.unet, UnifieldWrappedUNet)
+ self.unet.forward_hook = self.unet_forward_hook
+
+ if self.attn_config.cat_condition:
+ # double in_channels
+ if self.unet.config.in_channels != 8:
+ self.unet.register_to_config(in_channels=self.unet.config.in_channels * 2)
+            # widen conv_in: keep the original weights and zero-init the weights for the extra condition channels
+ doubled_conv_in = torch.nn.Conv2d(self.unet.conv_in.in_channels * 2, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
+ doubled_conv_in.weight.data = torch.cat([self.unet.conv_in.weight.data, torch.zeros_like(self.unet.conv_in.weight.data)], dim=1)
+ doubled_conv_in.bias.data = self.unet.conv_in.bias.data
+ self.unet.conv_in = doubled_conv_in
+
+ used_param_ids = set()
+
+ if attn_config.init_cross_attn_lora:
+ # setup lora
+ from peft import LoraConfig
+ from peft.utils import get_peft_model_state_dict
+ if attn_config.cross_attn_lora_only_kv:
+ target_modules=["attn2.to_k", "attn2.to_v"]
+ else:
+ target_modules=["attn2.to_k", "attn2.to_q", "attn2.to_v", "attn2.to_out.0"]
+ lora_config: LoraConfig = LoraConfig(
+ r=attn_config.cross_attn_lora_rank,
+ lora_alpha=attn_config.cross_attn_lora_rank,
+ init_lora_weights="gaussian",
+ target_modules=target_modules,
+ )
+ adapter_name="cross_attn_lora"
+ self.unet.add_adapter(lora_config, adapter_name=adapter_name)
+ # update cross_attn_lora_param_dict
+ self.cross_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids}
+ used_param_ids.update(self.cross_attn_lora_param_dict.keys())
+
+ if attn_config.init_self_attn_lora:
+ # setup lora
+ from peft import LoraConfig
+ if attn_config.self_attn_lora_only_kv:
+ target_modules=["attn1.to_k", "attn1.to_v"]
+ else:
+ target_modules=["attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0"]
+ lora_config: LoraConfig = LoraConfig(
+ r=attn_config.self_attn_lora_rank,
+ lora_alpha=attn_config.self_attn_lora_rank,
+ init_lora_weights="gaussian",
+ target_modules=target_modules,
+ )
+ adapter_name="self_attn_lora"
+ self.unet.add_adapter(lora_config, adapter_name=adapter_name)
+            # update self_attn_lora_param_dict
+ self.self_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids}
+ used_param_ids.update(self.self_attn_lora_param_dict.keys())
+
+ if attn_config.init_num_cls_label != 0:
+ self.cls_embedding_param_dict = {id(param): param for param in self.unet.class_embedding.parameters()}
+ used_param_ids.update(self.cls_embedding_param_dict.keys())
+ self.set_class_labels(torch.tensor(attn_config.cls_labels).long())
+
+ if attn_config.init_cross_attn_ip:
+ self.image_encoder = None
+ # setup ipadapter
+ self.load_ip_adapter(
+ attn_config.ipadapter_pretrained_name,
+ subfolder=attn_config.ipadapter_subfolder_name,
+ weight_name=attn_config.ipadapter_weight_name
+ )
+            # wrap the ip_adapter attention processors with a switch
+ from diffusers.models.attention_processor import IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0
+ add_switch(self.unet, module_filter=lambda x: isinstance(x, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)), switch_dict_fn=lambda x: {"ipadapter": x, "default": XFormersAttnProcessor()}, switch_name="ipadapter_switch", enabled_proc="ipadapter")
+ # update ipadapter_param_dict
+ # weights are in attention processors and unet.encoder_hid_proj
+ self.ipadapter_param_dict = {id(param): param for param in self.unet.encoder_hid_proj.parameters() if id(param) not in used_param_ids}
+ used_param_ids.update(self.ipadapter_param_dict.keys())
+ print("DEBUG: ipadapter_param_dict len in encoder_hid_proj", len(self.ipadapter_param_dict))
+ for name, processor in self.unet.attn_processors.items():
+ if hasattr(processor, "to_k_ip"):
+ self.ipadapter_param_dict.update({id(param): param for param in processor.parameters()})
+ print(f"DEBUG: ipadapter_param_dict len in all", len(self.ipadapter_param_dict))
+
+ ref_unet = None
+ if attn_config.init_self_attn_ref:
+ # setup reference attention processor
+ if attn_config.self_attn_ref_other_model_name == "self":
+ raise NotImplementedError("self reference is not fully implemented")
+ else:
+ ref_unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
+ attn_config.self_attn_ref_other_model_name, subfolder="unet", torch_dtype=self.unet.dtype
+ )
+ ref_unet.to(self.unet.device)
+ if self.attn_config.train_ref_unet_lr == 0:
+ ref_unet.eval()
+ ref_unet.requires_grad_(False)
+ else:
+ ref_unet.train()
+
+ add_extra_processor(
+ model=ref_unet,
+ enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"),
+ mode='extract',
+ with_proj_in=False,
+ pixel_wise_crosspond=False,
+ )
+            # NOTE: this requires the self-attention cross_attention_dim of the two UNets to match
+ processor_dict = add_extra_processor(
+ model=self.unet,
+ enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"),
+ mode='inject',
+ with_proj_in=False,
+ pixel_wise_crosspond=attn_config.self_attn_ref_pixel_wise_crosspond,
+ crosspond_effect_on=attn_config.self_attn_ref_effect_on,
+ crosspond_chain_pos=attn_config.self_attn_ref_chain_pos,
+ simple_3d=attn_config.use_simple3d_attn,
+ )
+ self.ref_unet_param_dict = {id(param): param for name, param in ref_unet.named_parameters() if id(param) not in used_param_ids and (attn_config.self_attn_ref_position in name)}
+ if attn_config.self_attn_ref_chain_pos != "after":
+                # pop untrainable parameters
+ for name, param in ref_unet.named_parameters():
+ if id(param) in self.ref_unet_param_dict and ('up_blocks.3.attentions.2.transformer_blocks.0.' in name):
+ self.ref_unet_param_dict.pop(id(param))
+ used_param_ids.update(self.ref_unet_param_dict.keys())
+ # update ref_attn_param_dict
+ self.ref_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids}
+ used_param_ids.update(self.ref_attn_param_dict.keys())
+
+ if attn_config.init_multiview_attn:
+ processor_dict = add_multiview_processor(
+ model = self.unet,
+ enable_filter = lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"),
+ num_modalities = attn_config.num_modalities,
+ base_img_size = attn_config.latent_size,
+ chain_pos = attn_config.multiview_chain_pose,
+ )
+ # update multiview_attn_param_dict
+ self.multiview_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids}
+ used_param_ids.update(self.multiview_attn_param_dict.keys())
+
+ # initialize cross_attn_param_dict parameters
+ self.cross_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn2" in name and id(param) not in used_param_ids}
+ used_param_ids.update(self.cross_attn_param_dict.keys())
+
+ # initialize self_attn_param_dict parameters
+ self.self_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn1" in name and id(param) not in used_param_ids}
+ used_param_ids.update(self.self_attn_param_dict.keys())
+
+ # initialize other_param_dict parameters
+ self.other_param_dict = {id(param): param for name, param in self.unet.named_parameters() if id(param) not in used_param_ids}
+
+ if ref_unet is not None:
+ self.unet.ref_unet = ref_unet
+
+ self.rev_param_name_mapping = {id(param): name for name, param in self.unet.named_parameters()}
+
+ self.update_config(attn_config, force_update=True)
+ return self.unet
+
+ _attn_keys_to_update = ["enable_cross_attn_lora", "enable_cross_attn_ip", "enable_self_attn_lora", "enable_self_attn_ref", "enable_multiview_attn", "cls_labels"]
+
+ def update_config(self, attn_config: AttnConfig, force_update=False):
+ assert isinstance(self.unet, UNet2DConditionModel), "unet must be an instance of UNet2DConditionModel"
+
+ need_to_update = False
+ # update cls_labels
+ for key in self._attn_keys_to_update:
+ if getattr(self.attn_config, key) != getattr(attn_config, key):
+ need_to_update = True
+ break
+ if not force_update and not need_to_update:
+ return
+
+ self.set_class_labels(torch.tensor(attn_config.cls_labels).long())
+
+ # setup loras
+ if self.attn_config.init_cross_attn_lora or self.attn_config.init_self_attn_lora:
+ if attn_config.enable_cross_attn_lora or attn_config.enable_self_attn_lora:
+                cross_attn_lora_weight = 1. if attn_config.enable_cross_attn_lora else 0.
+                self_attn_lora_weight = 1. if attn_config.enable_self_attn_lora else 0.
+ self.unet.set_adapters(["cross_attn_lora", "self_attn_lora"], weights=[cross_attn_lora_weight, self_attn_lora_weight])
+ else:
+ self.unet.disable_adapters()
+
+ # setup ipadapter
+ if self.attn_config.init_cross_attn_ip:
+ if attn_config.enable_cross_attn_ip:
+ change_switch(self.unet, "ipadapter_switch", "ipadapter")
+ else:
+ change_switch(self.unet, "ipadapter_switch", "default")
+
+ # setup reference attention processor
+ if self.attn_config.init_self_attn_ref:
+ if attn_config.enable_self_attn_ref:
+ switch_extra_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"))
+ else:
+ switch_extra_processor(self.unet, enable_filter=lambda name: False)
+
+ # setup multiview attention processor
+ if self.attn_config.init_multiview_attn:
+ if attn_config.enable_multiview_attn:
+ switch_multiview_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"))
+ else:
+ switch_multiview_processor(self.unet, enable_filter=lambda name: False)
+
+ # update cls_labels
+ for key in self._attn_keys_to_update:
+ setattr(self.attn_config, key, getattr(attn_config, key))
+
+ def unet_forward_hook(self, raw_forward, sample: torch.FloatTensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, *args, cross_attention_kwargs=None, condition_latents=None, class_labels=None, noisy_condition_input=False, cond_pixels_clip=None, **kwargs):
+ if class_labels is None and len(self.class_labels) > 0:
+ class_labels = self.class_labels.repeat(sample.shape[0] // self.class_labels.shape[0]).to(sample.device)
+ elif self.attn_config.init_num_cls_label != 0:
+ assert class_labels is not None, "class_labels should be passed if self.class_labels is empty and self.attn_config.init_num_cls_label is not 0"
+ if class_labels is not None:
+ if self.attn_config.cls_label_type == "embedding":
+ pass
+ else:
+ raise ValueError(f"cls_label_type {self.attn_config.cls_label_type} is not supported")
+ if self.attn_config.init_self_attn_ref and self.attn_config.enable_self_attn_ref:
+ # NOTE: extra step, extract condition
+ ref_dict = {}
+ ref_unet = self.get_refunet().to(sample.device)
+ assert condition_latents is not None
+ if self.attn_config.self_attn_ref_other_model_name == "self":
+ raise NotImplementedError()
+ else:
+ with torch.no_grad():
+ cond_encoder_hidden_states = encoder_hidden_states.reshape(condition_latents.shape[0], -1, *encoder_hidden_states.shape[1:])[:, 0]
+ if timestep.dim() == 0:
+ cond_timestep = timestep
+ else:
+ cond_timestep = timestep.reshape(condition_latents.shape[0], -1)[:, 0]
+ ref_unet(condition_latents, cond_timestep, cond_encoder_hidden_states, cross_attention_kwargs=dict(ref_dict=ref_dict))
+ # NOTE: extra step, inject condition
+ # Predict the noise residual and compute loss
+ if cross_attention_kwargs is None:
+ cross_attention_kwargs = {}
+ cross_attention_kwargs.update(ref_dict=ref_dict, mode='inject')
+ elif condition_latents is not None:
+ if not hasattr(self, 'condition_latents_raised'):
+ print("Warning! condition_latents is not None, but self_attn_ref is not enabled! This warning will only be raised once.")
+ self.condition_latents_raised = True
+
+ if self.attn_config.init_cross_attn_ip:
+ raise NotImplementedError()
+
+ if self.attn_config.cat_condition:
+ assert condition_latents is not None
+ B = condition_latents.shape[0]
+ cat_latents = condition_latents.reshape(B, 1, *condition_latents.shape[1:]).repeat(1, sample.shape[0] // B, 1, 1, 1).reshape(*sample.shape)
+ sample = torch.cat([sample, cat_latents], dim=1)
+
+ return raw_forward(sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, **kwargs)
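
`ConfigurableUNet2DConditionModel` is driven entirely by an `AttnConfig`. Below is a hedged sketch of constructing one with the reference-attention and multiview branches turned on and then flipping a runtime switch via `update_config`; the reference-model name and flag values are illustrative assumptions, not the checkpoints or settings used for the released models:

```python
# Illustrative sketch: flag values and the reference-model path are assumptions.
import torch
from dataclasses import replace
from custum_3d_diffusion.custum_modules.unifield_processor import (
    AttnConfig,
    ConfigurableUNet2DConditionModel,
)

cfg = AttnConfig(
    init_unet_path="runwayml/stable-diffusion-v1-5",
    init_self_attn_ref=True,
    enable_self_attn_ref=True,
    self_attn_ref_other_model_name="lambdalabs/sd-image-variations-diffusers",  # assumed
    self_attn_ref_pixel_wise_crosspond=True,
    init_multiview_attn=True,
    enable_multiview_attn=True,
)
wrapper = ConfigurableUNet2DConditionModel(cfg, weight_dtype=torch.float32)

# Keys listed in _attn_keys_to_update can be flipped later without rebuilding
# the UNet, e.g. disabling the reference branch at inference time.
wrapper.update_config(replace(cfg, enable_self_attn_ref=False))
```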
diff --git a/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py b/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e6ef50a3ddd5c2a7813916bbbabb31c5886f01f
--- /dev/null
+++ b/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py
@@ -0,0 +1,298 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# modified by Wuvin
+
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+
+
+class StableDiffusionImageCustomPipeline(
+ StableDiffusionImageVariationPipeline
+):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ image_encoder: CLIPVisionModelWithProjection,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ latents_offset=None,
+ noisy_cond_latents=False,
+ ):
+ super().__init__(
+ vae=vae,
+ image_encoder=image_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=requires_safety_checker
+ )
+ latents_offset = tuple(latents_offset) if latents_offset is not None else None
+ self.latents_offset = latents_offset
+ if latents_offset is not None:
+ self.register_to_config(latents_offset=latents_offset)
+ self.noisy_cond_latents = noisy_cond_latents
+ self.register_to_config(noisy_cond_latents=noisy_cond_latents)
+
+ def encode_latents(self, image, device, dtype, height, width):
+ # support batchsize > 1
+ if isinstance(image, Image.Image):
+ image = [image]
+ image = [img.convert("RGB") for img in image]
+ images = self.image_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype)
+ latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
+ if self.latents_offset is not None:
+ return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
+ else:
+ return latents
+
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeddings = self.image_encoder(image).image_embeds
+ image_embeddings = image_embeddings.unsqueeze(1)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # NOTE: the same as original code
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
+
+ return image_embeddings
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
+ height: Optional[int] = 1024,
+ width: Optional[int] = 1024,
+ height_cond: Optional[int] = 512,
+ width_cond: Optional[int] = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ upper_left_feature: bool = False,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter is modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+
+ Examples:
+
+ ```py
+ from diffusers import StableDiffusionImageVariationPipeline
+ from PIL import Image
+ from io import BytesIO
+ import requests
+
+ pipe = StableDiffusionImageVariationPipeline.from_pretrained(
+ "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
+ )
+ pipe = pipe.to("cuda")
+
+ url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
+
+ response = requests.get(url)
+ image = Image.open(BytesIO(response.content)).convert("RGB")
+
+ out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
+ out["images"][0].save("result.jpg")
+ ```
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(image, height, width, callback_steps)
+
+ # 2. Define call parameters
+ if isinstance(image, Image.Image):
+ batch_size = 1
+ elif isinstance(image, list):
+ batch_size = len(image)
+ else:
+ batch_size = image.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input image
+ if isinstance(image, Image.Image) and upper_left_feature:
+            # use only the first (upper-left) of the four tiled views for the CLIP image embedding
+ emb_image = image.crop((0, 0, image.size[0] // 2, image.size[1] // 2))
+ else:
+ emb_image = image
+
+ image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
+ cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.out_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.noisy_cond_latents:
+ raise ValueError("Noisy condition latents is not recommended.")
+ else:
+ noisy_cond_latents = cond_latents
+
+ noisy_cond_latents = torch.cat([torch.zeros_like(noisy_cond_latents), noisy_cond_latents]) if do_classifier_free_guidance else noisy_cond_latents
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=noisy_cond_latents).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ self.maybe_free_model_hooks()
+
+ if self.latents_offset is not None:
+ latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+if __name__ == "__main__":
+ pass
diff --git a/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py b/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py
new file mode 100644
index 0000000000000000000000000000000000000000..de342d1b9767b6d1cea138bb24d2d2fff34229fc
--- /dev/null
+++ b/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py
@@ -0,0 +1,296 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# modified by Wuvin
+
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
+from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+
+
+class StableDiffusionImage2MVCustomPipeline(
+ StableDiffusionImageVariationPipeline
+):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ image_encoder: CLIPVisionModelWithProjection,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ latents_offset=None,
+ noisy_cond_latents=False,
+ condition_offset=True,
+ ):
+ super().__init__(
+ vae=vae,
+ image_encoder=image_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=requires_safety_checker
+ )
+ latents_offset = tuple(latents_offset) if latents_offset is not None else None
+ self.latents_offset = latents_offset
+ if latents_offset is not None:
+ self.register_to_config(latents_offset=latents_offset)
+ if noisy_cond_latents:
+ raise NotImplementedError("Noisy condition latents not supported Now.")
+ self.condition_offset = condition_offset
+ self.register_to_config(condition_offset=condition_offset)
+
+ def encode_latents(self, image: Image.Image, device, dtype, height, width):
+ images = self.image_processor.preprocess(image.convert("RGB"), height=height, width=width).to(device, dtype=dtype)
+        # NOTE: use the deterministic .mode() of the latent distribution (rather than .sample()) for the condition
+ latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
+ if self.latents_offset is not None and self.condition_offset:
+ return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
+ else:
+ return latents
+
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeddings = self.image_encoder(image).image_embeds
+ image_embeddings = image_embeddings.unsqueeze(1)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # NOTE: the same as original code
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
+
+ return image_embeddings
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
+ height: Optional[int] = 1024,
+ width: Optional[int] = 1024,
+ height_cond: Optional[int] = 512,
+ width_cond: Optional[int] = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter is modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+
+ Examples:
+
+ ```py
+ from diffusers import StableDiffusionImageVariationPipeline
+ from PIL import Image
+ from io import BytesIO
+ import requests
+
+ pipe = StableDiffusionImageVariationPipeline.from_pretrained(
+ "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
+ )
+ pipe = pipe.to("cuda")
+
+ url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
+
+ response = requests.get(url)
+ image = Image.open(BytesIO(response.content)).convert("RGB")
+
+ out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
+ out["images"][0].save("result.jpg")
+ ```
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(image, height, width, callback_steps)
+
+ # 2. Define call parameters
+ if isinstance(image, Image.Image):
+ batch_size = 1
+ elif len(image) == 1:
+ image = image[0]
+ batch_size = 1
+ else:
+ raise NotImplementedError()
+ # elif isinstance(image, list):
+ # batch_size = len(image)
+ # else:
+ # batch_size = image.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input image
+ emb_image = image
+
+ image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
+ cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
+ cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents]) if do_classifier_free_guidance else cond_latents
+ image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values
+ if do_classifier_free_guidance:
+ image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.out_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+
+ # 6. Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=cond_latents, noisy_condition_input=False, cond_pixels_clip=image_pixels).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ self.maybe_free_model_hooks()
+
+ if self.latents_offset is not None:
+ latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+if __name__ == "__main__":
+ pass
diff --git a/custum_3d_diffusion/modules.py b/custum_3d_diffusion/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..85a0af49c8dff30e64f9f5c1ba94ff697702523a
--- /dev/null
+++ b/custum_3d_diffusion/modules.py
@@ -0,0 +1,14 @@
+__modules__ = {}
+
+def register(name):
+ def decorator(cls):
+ __modules__[name] = cls
+ return cls
+
+ return decorator
+
+
+def find(name):
+ return __modules__[name]
+
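+# Minimal usage sketch (the trainer name and class below are hypothetical):
+#
+#   @register("my_trainer")
+#   class MyTrainer: ...
+#
+#   trainer_cls = find("my_trainer")   # -> MyTrainer
+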
+from custum_3d_diffusion.trainings import base, image2mvimage_trainer, image2image_trainer
diff --git a/custum_3d_diffusion/trainings/__init__.py b/custum_3d_diffusion/trainings/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/custum_3d_diffusion/trainings/base.py b/custum_3d_diffusion/trainings/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4ce969190b788b34f9c777b30413b5d6d79a349
--- /dev/null
+++ b/custum_3d_diffusion/trainings/base.py
@@ -0,0 +1,208 @@
+import torch
+from accelerate import Accelerator
+from accelerate.logging import MultiProcessAdapter
+from dataclasses import dataclass, field
+from typing import Optional, Union
+from datasets import load_dataset
+import json
+import abc
+from diffusers.utils import make_image_grid
+import numpy as np
+import wandb
+
+from custum_3d_diffusion.trainings.utils import load_config
+from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig
+
+class BasicTrainer(torch.nn.Module, abc.ABC):
+ accelerator: Accelerator
+ logger: MultiProcessAdapter
+ unet: ConfigurableUNet2DConditionModel
+ train_dataloader: torch.utils.data.DataLoader
+ test_dataset: torch.utils.data.Dataset
+ attn_config: AttnConfig
+
+ @dataclass
+ class TrainerConfig:
+ trainer_name: str = "basic"
+ pretrained_model_name_or_path: str = ""
+
+ attn_config: dict = field(default_factory=dict)
+ dataset_name: str = ""
+ dataset_config_name: Optional[str] = None
+ resolution: str = "1024"
+ dataloader_num_workers: int = 4
+ pair_sampler_group_size: int = 1
+ num_views: int = 4
+
+ max_train_steps: int = -1 # -1 means infinity, otherwise [0, max_train_steps)
+ training_step_interval: int = 1 # train on step i*interval, stop at max_train_steps
+ max_train_samples: Optional[int] = None
+ seed: Optional[int] = None # For dataset related operations and validation stuff
+ train_batch_size: int = 1
+
+ validation_interval: int = 5000
+ debug: bool = False
+
+ cfg: TrainerConfig # only enable_xxx is used
+
+ def __init__(
+ self,
+ accelerator: Accelerator,
+ logger: MultiProcessAdapter,
+ unet: ConfigurableUNet2DConditionModel,
+ config: Union[dict, str],
+ weight_dtype: torch.dtype,
+ index: int,
+ ):
+ super().__init__()
+ self.index = index # index in all trainers
+ self.accelerator = accelerator
+ self.logger = logger
+ self.unet = unet
+ self.weight_dtype = weight_dtype
+ self.ext_logs = {}
+ self.cfg = load_config(self.TrainerConfig, config)
+ self.attn_config = load_config(AttnConfig, self.cfg.attn_config)
+ self.test_dataset = None
+ self.validate_trainer_config()
+ self.configure()
+
+ def get_HW(self):
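+        # cfg.resolution is a JSON string, e.g. "1024" for a square H = W = 1024, or
+        # "[768, 512]" for H = 768, W = 512 (both values here are illustrative).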
+ resolution = json.loads(self.cfg.resolution)
+ if isinstance(resolution, int):
+ H = W = resolution
+ elif isinstance(resolution, list):
+ H, W = resolution
+ return H, W
+
+ def unet_update(self):
+ self.unet.update_config(self.attn_config)
+
+ def validate_trainer_config(self):
+ pass
+
+ def is_train_finished(self, current_step):
+ assert isinstance(self.cfg.max_train_steps, int)
+ return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps
+
+ def next_train_step(self, current_step):
+ if self.is_train_finished(current_step):
+ return None
+ return current_step + self.cfg.training_step_interval
+
+ @classmethod
+ def make_image_into_grid(cls, all_imgs, rows=2, columns=2):
+ catted = [make_image_grid(all_imgs[i:i+rows * columns], rows=rows, cols=columns) for i in range(0, len(all_imgs), rows * columns)]
+ return make_image_grid(catted, rows=1, cols=len(catted))
+
+ def configure(self) -> None:
+ pass
+
+ @abc.abstractmethod
+ def init_shared_modules(self, shared_modules: dict) -> dict:
+ pass
+
+ def load_dataset(self):
+ dataset = load_dataset(
+ self.cfg.dataset_name,
+ self.cfg.dataset_config_name,
+ trust_remote_code=True
+ )
+ return dataset
+
+ @abc.abstractmethod
+ def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
+ """Both init train_dataloader and test_dataset, but returns train_dataloader only"""
+ pass
+
+ @abc.abstractmethod
+ def forward_step(
+ self,
+ *args,
+ **kwargs
+ ) -> torch.Tensor:
+ """
+ input a batch
+ return a loss
+ """
+ self.unet_update()
+ pass
+
+ @abc.abstractmethod
+ def construct_pipeline(self, shared_modules, unet):
+ pass
+
+ @abc.abstractmethod
+ def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
+ """
+ For inference time forward.
+ """
+ pass
+
+ @abc.abstractmethod
+ def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
+ pass
+
+ def do_validation(
+ self,
+ shared_modules,
+ unet,
+ global_step,
+ ):
+ self.unet_update()
+ self.logger.info("Running validation... ")
+ pipeline = self.construct_pipeline(shared_modules, unet)
+ pipeline.set_progress_bar_config(disable=True)
+ titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.])
+ for tracker in self.accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ [image.thumbnail((512, 512)) for image, title in zip(images, titles) if 'noresize' not in title] # inplace operation
+ tracker.log({"validation": [
+ wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg")
+ for i, image in enumerate(images)]})
+ else:
+ self.logger.warn(f"image logging not implemented for {tracker.name}")
+ del pipeline
+ torch.cuda.empty_cache()
+ return images
+
+
+ @torch.no_grad()
+ def log_validation(
+ self,
+ shared_modules,
+ unet,
+ global_step,
+ force=False
+ ):
+ if self.accelerator.is_main_process:
+ for tracker in self.accelerator.trackers:
+ if tracker.name == "wandb":
+ tracker.log(self.ext_logs)
+ self.ext_logs = {}
+ if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force:
+ self.unet_update()
+ if self.accelerator.is_main_process:
+ self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step)
+
+ def save_model(self, unwrap_unet, shared_modules, save_dir):
+ if self.accelerator.is_main_process:
+ pipeline = self.construct_pipeline(shared_modules, unwrap_unet)
+ pipeline.save_pretrained(save_dir)
+ self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}")
+
+ def save_debug_info(self, save_name="debug", **kwargs):
+ if self.cfg.debug:
+ to_saves = {key: value.detach().cpu() if isinstance(value, torch.Tensor) else value for key, value in kwargs.items()}
+ import pickle
+ import os
+ if os.path.exists(f"{save_name}.pkl"):
+ for i in range(100):
+ if not os.path.exists(f"{save_name}_v{i}.pkl"):
+ save_name = f"{save_name}_v{i}"
+ break
+ with open(f"{save_name}.pkl", "wb") as f:
+ pickle.dump(to_saves, f)
\ No newline at end of file
diff --git a/custum_3d_diffusion/trainings/config_classes.py b/custum_3d_diffusion/trainings/config_classes.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4e54921d1c648a1ac88c109873419b27a43e015
--- /dev/null
+++ b/custum_3d_diffusion/trainings/config_classes.py
@@ -0,0 +1,35 @@
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+
+@dataclass
+class TrainerSubConfig:
+ trainer_type: str = ""
+ trainer: dict = field(default_factory=dict)
+
+
+@dataclass
+class ExprimentConfig:
+ trainers: List[dict] = field(default_factory=lambda: [])
+ init_config: dict = field(default_factory=dict)
+ pretrained_model_name_or_path: str = ""
+ pretrained_unet_state_dict_path: str = ""
+    # experiment-related parameters
+ linear_beta_schedule: bool = False
+ zero_snr: bool = False
+ prediction_type: Optional[str] = None
+ seed: Optional[int] = None
+ max_train_steps: int = 1000000
+ gradient_accumulation_steps: int = 1
+ learning_rate: float = 1e-4
+ lr_scheduler: str = "constant"
+ lr_warmup_steps: int = 500
+ use_8bit_adam: bool = False
+ adam_beta1: float = 0.9
+ adam_beta2: float = 0.999
+ adam_weight_decay: float = 1e-2
+ adam_epsilon: float = 1e-08
+ max_grad_norm: float = 1.0
+ mixed_precision: Optional[str] = None # ["no", "fp16", "bf16", "fp8"]
+ skip_training: bool = False
+ debug: bool = False
\ No newline at end of file
diff --git a/custum_3d_diffusion/trainings/image2image_trainer.py b/custum_3d_diffusion/trainings/image2image_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f22c8efefb0f8b8c58cf4a301567919f241222c
--- /dev/null
+++ b/custum_3d_diffusion/trainings/image2image_trainer.py
@@ -0,0 +1,86 @@
+import json
+import torch
+from diffusers import EulerAncestralDiscreteScheduler, DDPMScheduler
+from dataclasses import dataclass
+
+from custum_3d_diffusion.modules import register
+from custum_3d_diffusion.trainings.image2mvimage_trainer import Image2MVImageTrainer
+from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2img import StableDiffusionImageCustomPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+
+def get_HW(resolution):
+ if isinstance(resolution, str):
+ resolution = json.loads(resolution)
+ if isinstance(resolution, int):
+ H = W = resolution
+ elif isinstance(resolution, list):
+ H, W = resolution
+ return H, W
+
+
+@register("image2image_trainer")
+class Image2ImageTrainer(Image2MVImageTrainer):
+ """
+    Trainer for simple image-to-image generation.
+ """
+ @dataclass
+ class TrainerConfig(Image2MVImageTrainer.TrainerConfig):
+ trainer_name: str = "image2image"
+
+ cfg: TrainerConfig
+
+ def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
+ raise NotImplementedError()
+
+ def construct_pipeline(self, shared_modules, unet, old_version=False):
+ MyPipeline = StableDiffusionImageCustomPipeline
+ pipeline = MyPipeline.from_pretrained(
+ self.cfg.pretrained_model_name_or_path,
+ vae=shared_modules['vae'],
+ image_encoder=shared_modules['image_encoder'],
+ feature_extractor=shared_modules['feature_extractor'],
+ unet=unet,
+ safety_checker=None,
+ torch_dtype=self.weight_dtype,
+ latents_offset=self.cfg.latents_offset,
+ noisy_cond_latents=self.cfg.noisy_condition_input,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+ scheduler_dict = {}
+ if self.cfg.zero_snr:
+ scheduler_dict.update(rescale_betas_zero_snr=True)
+ if self.cfg.linear_beta_schedule:
+ scheduler_dict.update(beta_schedule='linear')
+
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
+ return pipeline
+
+ def get_forward_args(self):
+ if self.cfg.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)
+
+ H, W = get_HW(self.cfg.resolution)
+ H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)
+
+ forward_args = dict(
+ num_images_per_prompt=1,
+ num_inference_steps=20,
+ height=H,
+ width=W,
+ height_cond=H_cond,
+ width_cond=W_cond,
+ generator=generator,
+ )
+ if self.cfg.zero_snr:
+ forward_args.update(guidance_rescale=0.7)
+ return forward_args
+
+ def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
+ forward_args = self.get_forward_args()
+ forward_args.update(pipeline_call_kwargs)
+ return pipeline(**forward_args)
+
+ def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
+ raise NotImplementedError()
\ No newline at end of file
diff --git a/custum_3d_diffusion/trainings/image2mvimage_trainer.py b/custum_3d_diffusion/trainings/image2mvimage_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1e6117eed017495e663d9bf6f289238a35c5b88
--- /dev/null
+++ b/custum_3d_diffusion/trainings/image2mvimage_trainer.py
@@ -0,0 +1,139 @@
+import torch
+from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, BatchFeature
+
+import json
+from dataclasses import dataclass
+from typing import List, Optional
+
+from custum_3d_diffusion.modules import register
+from custum_3d_diffusion.trainings.base import BasicTrainer
+from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+
+def get_HW(resolution):
+ if isinstance(resolution, str):
+ resolution = json.loads(resolution)
+ if isinstance(resolution, int):
+ H = W = resolution
+ elif isinstance(resolution, list):
+ H, W = resolution
+ return H, W
+
+@register("image2mvimage_trainer")
+class Image2MVImageTrainer(BasicTrainer):
+ """
+ Trainer for simple image to multiview images.
+ """
+ @dataclass
+ class TrainerConfig(BasicTrainer.TrainerConfig):
+ trainer_name: str = "image2mvimage"
+ condition_image_column_name: str = "conditioning_image"
+ image_column_name: str = "image"
+ condition_dropout: float = 0.
+ condition_image_resolution: str = "512"
+ validation_images: Optional[List[str]] = None
+ noise_offset: float = 0.1
+ max_loss_drop: float = 0.
+ snr_gamma: float = 5.0
+ log_distribution: bool = False
+ latents_offset: Optional[List[float]] = None
+ input_perturbation: float = 0.
+ noisy_condition_input: bool = False # whether to add noise for ref unet input
+ normal_cls_offset: int = 0
+ condition_offset: bool = True
+ zero_snr: bool = False
+ linear_beta_schedule: bool = False
+
+ cfg: TrainerConfig
+
+ def configure(self) -> None:
+ return super().configure()
+
+ def init_shared_modules(self, shared_modules: dict) -> dict:
+ if 'vae' not in shared_modules:
+ vae = AutoencoderKL.from_pretrained(
+ self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype
+ )
+ vae.requires_grad_(False)
+ vae.to(self.accelerator.device, dtype=self.weight_dtype)
+ shared_modules['vae'] = vae
+ if 'image_encoder' not in shared_modules:
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ self.cfg.pretrained_model_name_or_path, subfolder="image_encoder"
+ )
+ image_encoder.requires_grad_(False)
+ image_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
+ shared_modules['image_encoder'] = image_encoder
+ if 'feature_extractor' not in shared_modules:
+ feature_extractor = CLIPImageProcessor.from_pretrained(
+ self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor"
+ )
+ shared_modules['feature_extractor'] = feature_extractor
+ return shared_modules
+
+ def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
+ raise NotImplementedError()
+
+ def loss_rescale(self, loss, timesteps=None):
+ raise NotImplementedError()
+
+ def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
+ raise NotImplementedError()
+
+ def construct_pipeline(self, shared_modules, unet, old_version=False):
+ MyPipeline = StableDiffusionImage2MVCustomPipeline
+ pipeline = MyPipeline.from_pretrained(
+ self.cfg.pretrained_model_name_or_path,
+ vae=shared_modules['vae'],
+ image_encoder=shared_modules['image_encoder'],
+ feature_extractor=shared_modules['feature_extractor'],
+ unet=unet,
+ safety_checker=None,
+ torch_dtype=self.weight_dtype,
+ latents_offset=self.cfg.latents_offset,
+ noisy_cond_latents=self.cfg.noisy_condition_input,
+ condition_offset=self.cfg.condition_offset,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+ scheduler_dict = {}
+ if self.cfg.zero_snr:
+ scheduler_dict.update(rescale_betas_zero_snr=True)
+ if self.cfg.linear_beta_schedule:
+ scheduler_dict.update(beta_schedule='linear')
+
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
+ return pipeline
+
+ def get_forward_args(self):
+ if self.cfg.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)
+
+ H, W = get_HW(self.cfg.resolution)
+ H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)
+
+ sub_img_H = H // 2
+ num_imgs = H // sub_img_H * W // sub_img_H
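+        # e.g. (illustrative) a square resolution of "2048" gives 1024x1024 sub-views and
+        # num_imgs = 4, i.e. the multiview output is tiled as a 2x2 grid of views.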
+
+ forward_args = dict(
+ num_images_per_prompt=num_imgs,
+ num_inference_steps=50,
+ height=sub_img_H,
+ width=sub_img_H,
+ height_cond=H_cond,
+ width_cond=W_cond,
+ generator=generator,
+ )
+ if self.cfg.zero_snr:
+ forward_args.update(guidance_rescale=0.7)
+ return forward_args
+
+ def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
+ forward_args = self.get_forward_args()
+ forward_args.update(pipeline_call_kwargs)
+ return pipeline(**forward_args)
+
+ def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
+ raise NotImplementedError()
\ No newline at end of file
diff --git a/custum_3d_diffusion/trainings/utils.py b/custum_3d_diffusion/trainings/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..58637e155304f2914d4efca97b792976ac9c86f4
--- /dev/null
+++ b/custum_3d_diffusion/trainings/utils.py
@@ -0,0 +1,25 @@
+from omegaconf import DictConfig, OmegaConf
+
+
+def parse_structured(fields, cfg) -> DictConfig:
+ scfg = OmegaConf.structured(fields(**cfg))
+ return scfg
+
+
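+# Usage sketch for load_config (the dataclass and paths below are hypothetical):
+#
+#   @dataclass
+#   class MyConfig:
+#       lr: float = 1e-4
+#
+#   cfg = load_config(MyConfig, "configs/my_config.yaml")            # from a YAML file
+#   cfg = load_config(MyConfig, {"lr": 3e-4}, extras=["lr=1e-3"])    # dict + CLI-style overrides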
+def load_config(fields, config, extras=None):
+ if extras is not None:
+ print("Warning! extra parameter in cli is not verified, may cause erros.")
+ if isinstance(config, str):
+ cfg = OmegaConf.load(config)
+ elif isinstance(config, dict):
+ cfg = OmegaConf.create(config)
+ elif isinstance(config, DictConfig):
+ cfg = config
+ else:
+ raise NotImplementedError(f"Unsupported config type {type(config)}")
+ if extras is not None:
+ cli_conf = OmegaConf.from_cli(extras)
+ cfg = OmegaConf.merge(cfg, cli_conf)
+ OmegaConf.resolve(cfg)
+ assert isinstance(cfg, DictConfig)
+ return parse_structured(fields, cfg)
\ No newline at end of file
diff --git a/gradio_app.py b/gradio_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d3fa6c529e330d59f457f5791633f2bd8a91139
--- /dev/null
+++ b/gradio_app.py
@@ -0,0 +1,41 @@
+if __name__ == "__main__":
+ import os
+ import sys
+ sys.path.append(os.curdir)
+ import torch
+ torch.set_float32_matmul_precision('medium')
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.set_grad_enabled(False)
+
+import fire
+import gradio as gr
+from app.gradio_3dgen import create_ui as create_3d_ui
+from app.all_models import model_zoo
+
+
+_TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image'''
+_DESCRIPTION = '''
+[Project page](https://wukailu.github.io/Unique3D/)
+
+* High-fidelity and diverse textured meshes generated by Unique3D from single-view images.
+
+* The demo is still under construction, and more features are expected to be implemented soon.
+'''
+
+def launch():
+ model_zoo.init_models()
+
+ with gr.Blocks(
+ title=_TITLE,
+ theme=gr.themes.Monochrome(),
+ ) as demo:
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown('# ' + _TITLE)
+ gr.Markdown(_DESCRIPTION)
+ create_3d_ui("wkl")
+
+ demo.queue().launch(share=True)
+
+if __name__ == '__main__':
+ fire.Fire(launch)
diff --git a/mesh_reconstruction/func.py b/mesh_reconstruction/func.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f93a8d2ef453216d8edab76280750a91588a371
--- /dev/null
+++ b/mesh_reconstruction/func.py
@@ -0,0 +1,118 @@
+# modified from https://github.com/Profactor/continuous-remeshing
+import torch
+import numpy as np
+import trimesh
+from typing import Tuple
+
+def to_numpy(*args):
+ def convert(a):
+ if isinstance(a,torch.Tensor):
+ return a.detach().cpu().numpy()
+ assert a is None or isinstance(a,np.ndarray)
+ return a
+
+ return convert(args[0]) if len(args)==1 else tuple(convert(a) for a in args)
+
+def laplacian(
+ num_verts:int,
+ edges: torch.Tensor #E,2
+ ) -> torch.Tensor: #sparse V,V
+ """create sparse Laplacian matrix"""
+ V = num_verts
+ E = edges.shape[0]
+
+ #adjacency matrix,
+ idx = torch.cat([edges, edges.fliplr()], dim=0).type(torch.long).T # (2, 2*E)
+ ones = torch.ones(2*E, dtype=torch.float32, device=edges.device)
+ A = torch.sparse.FloatTensor(idx, ones, (V, V))
+
+ #degree matrix
+ deg = torch.sparse.sum(A, dim=1).to_dense()
+ idx = torch.arange(V, device=edges.device)
+ idx = torch.stack([idx, idx], dim=0)
+ D = torch.sparse.FloatTensor(idx, deg, (V, V))
+
+ return D - A
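+# Worked example (sketch): for a single triangle with edges [[0,1],[1,2],[0,2]],
+# laplacian(3, edges) is the 3x3 matrix D - A with 2 on the diagonal (each vertex
+# touches two edges) and -1 between every pair of connected vertices.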
+
+def _translation(x, y, z, device):
+ return torch.tensor([[1., 0, 0, x],
+ [0, 1, 0, y],
+ [0, 0, 1, z],
+ [0, 0, 0, 1]],device=device) #4,4
+
+def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
+ """
+ see https://blog.csdn.net/wodownload2/article/details/85069240/
+ """
+ if l is None:
+ l = -r
+ if t is None:
+ t = r
+ if b is None:
+ b = -t
+ p = torch.zeros([4,4],device=device)
+ p[0,0] = 2*n/(r-l)
+ p[0,2] = (r+l)/(r-l)
+ p[1,1] = 2*n/(t-b) * (-1 if flip_y else 1)
+ p[1,2] = (t+b)/(t-b)
+ p[2,2] = -(f+n)/(f-n)
+ p[2,3] = -(2*f*n)/(f-n)
+ p[3,2] = -1
+ return p #4,4
+
+def _orthographic(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
+ if l is None:
+ l = -r
+ if t is None:
+ t = r
+ if b is None:
+ b = -t
+ o = torch.zeros([4,4],device=device)
+ o[0,0] = 2/(r-l)
+ o[0,3] = -(r+l)/(r-l)
+ o[1,1] = 2/(t-b) * (-1 if flip_y else 1)
+ o[1,3] = -(t+b)/(t-b)
+ o[2,2] = -2/(f-n)
+ o[2,3] = -(f+n)/(f-n)
+ o[3,3] = 1
+ return o #4,4
+
+def make_star_cameras(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'):
+ if r is None:
+ r = 1/distance
+ A = az_count
+ P = pol_count
+ C = A * P
+
+ phi = torch.arange(0,A) * (2*torch.pi/A)
+ phi_rot = torch.eye(3,device=device)[None,None].expand(A,1,3,3).clone()
+ phi_rot[:,0,2,2] = phi.cos()
+ phi_rot[:,0,2,0] = -phi.sin()
+ phi_rot[:,0,0,2] = phi.sin()
+ phi_rot[:,0,0,0] = phi.cos()
+
+ theta = torch.arange(1,P+1) * (torch.pi/(P+1)) - torch.pi/2
+ theta_rot = torch.eye(3,device=device)[None,None].expand(1,P,3,3).clone()
+ theta_rot[0,:,1,1] = theta.cos()
+ theta_rot[0,:,1,2] = -theta.sin()
+ theta_rot[0,:,2,1] = theta.sin()
+ theta_rot[0,:,2,2] = theta.cos()
+
+ mv = torch.empty((C,4,4), device=device)
+ mv[:] = torch.eye(4, device=device)
+ mv[:,:3,:3] = (theta_rot @ phi_rot).reshape(C,3,3)
+ mv = _translation(0, 0, -distance, device) @ mv
+
+ return mv, _projection(r,device)
+
+def make_star_cameras_orthographic(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'):
+ mv, _ = make_star_cameras(az_count,pol_count,distance,r,image_size,device)
+ if r is None:
+ r = 1
+ return mv, _orthographic(r,device)
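+# Usage sketch: make_star_cameras_orthographic(4, 1) yields 4 model-view matrices at
+# 90-degree azimuth steps (zero elevation) plus one orthographic projection matrix;
+# this is how the four-view mesh reconstruction below sets up its cameras.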
+
+def make_sphere(level:int=2,radius=1.,device='cuda') -> Tuple[torch.Tensor,torch.Tensor]:
+ sphere = trimesh.creation.icosphere(subdivisions=level, radius=1.0, color=None)
+ vertices = torch.tensor(sphere.vertices, device=device, dtype=torch.float32) * radius
+ faces = torch.tensor(sphere.faces, device=device, dtype=torch.long)
+ return vertices,faces
\ No newline at end of file
diff --git a/mesh_reconstruction/opt.py b/mesh_reconstruction/opt.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a51ef20b2cae93f14b1781ba4b91b48ed1ae8d0
--- /dev/null
+++ b/mesh_reconstruction/opt.py
@@ -0,0 +1,190 @@
+# modified from https://github.com/Profactor/continuous-remeshing
+import time
+import torch
+import torch_scatter
+from typing import Tuple
+from mesh_reconstruction.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges
+
+@torch.no_grad()
+def remesh(
+ vertices_etc:torch.Tensor, #V,D
+ faces:torch.Tensor, #F,3 long
+ min_edgelen:torch.Tensor, #V
+ max_edgelen:torch.Tensor, #V
+ flip:bool,
+ max_vertices=1e6
+ ):
+
+ # dummies
+ vertices_etc,faces = prepend_dummies(vertices_etc,faces)
+ vertices = vertices_etc[:,:3] #V,3
+ nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device)
+ min_edgelen = torch.concat((nan_tensor,min_edgelen))
+ max_edgelen = torch.concat((nan_tensor,max_edgelen))
+
+ # collapse
+ edges,face_to_edge = calc_edges(faces) #E,2 F,3
+ edge_length = calc_edge_length(vertices,edges) #E
+ face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3
+ vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3
+ face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5)
+ shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0
+ priority = face_collapse.float() + shortness
+ vertices_etc,faces = collapse_edges(vertices_etc,faces,edges,priority)
+
+ # split
+    if vertices.shape[0]<max_vertices:
+        edges,face_to_edge = calc_edges(faces) #E,2 F,3
+        vertices = vertices_etc[:,:3] #V,3
+        edge_length = calc_edge_length(vertices,edges) #E
+        splits = edge_length > max_edgelen[edges].mean(dim=-1)
+ vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False)
+
+ vertices_etc,faces = pack(vertices_etc,faces)
+ vertices = vertices_etc[:,:3]
+
+ if flip:
+ edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3
+ flip_edges(vertices,faces,edges,edge_to_face,with_border=False)
+
+ return remove_dummies(vertices_etc,faces)
+
+def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int):
+ """lerp with adam's bias correction"""
+ c_prev = 1-weight**(step-1)
+ c = 1-weight**step
+ a_weight = weight*c_prev/c
+ b_weight = (1-weight)/c
+ a.mul_(a_weight).add_(b, alpha=b_weight)
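+    # The update above keeps `a` equal to the bias-corrected running average
+    # m_t / (1 - weight**step), i.e. Adam's m_hat/v_hat without storing m_t separately.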
+
+
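+# Typical usage sketch (mirrors how reconstruct_stage1 / run_mesh_refine drive it;
+# compute_loss is a hypothetical stand-in for the rendering losses used there):
+#
+#   opt = MeshOptimizer(vertices, faces)
+#   vertices = opt.vertices
+#   for _ in range(steps):
+#       opt.zero_grad()
+#       loss = compute_loss(vertices, faces)   # e.g. render normals and compare to targets
+#       loss.backward()
+#       opt.step()
+#       vertices, faces = opt.remesh()
+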
+class MeshOptimizer:
+ """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh()."""
+
+ def __init__(self,
+ vertices:torch.Tensor, #V,3
+ faces:torch.Tensor, #F,3
+ lr=0.3, #learning rate
+ betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu
+ gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. smoothing)
+ nu_ref=0.3, #reference velocity for edge length controller
+ edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length
+ edge_len_tol=.5, #edge length tolerance for split and collapse
+ gain=.2, #gain value for edge length controller
+ laplacian_weight=.02, #for laplacian smoothing/regularization
+ ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0])
+ grad_lim=10., #gradients are clipped to m1.abs()*grad_lim
+ remesh_interval=1, #larger intervals are faster but with worse mesh quality
+ local_edgelen=True, #set to False to use a global scalar reference edge length instead
+ ):
+ self._vertices = vertices
+ self._faces = faces
+ self._lr = lr
+ self._betas = betas
+ self._gammas = gammas
+ self._nu_ref = nu_ref
+ self._edge_len_lims = edge_len_lims
+ self._edge_len_tol = edge_len_tol
+ self._gain = gain
+ self._laplacian_weight = laplacian_weight
+ self._ramp = ramp
+ self._grad_lim = grad_lim
+ self._remesh_interval = remesh_interval
+ self._local_edgelen = local_edgelen
+ self._step = 0
+
+ V = self._vertices.shape[0]
+ # prepare continuous tensor for all vertex-based data
+ self._vertices_etc = torch.zeros([V,9],device=vertices.device)
+ self._split_vertices_etc()
+ self.vertices.copy_(vertices) #initialize vertices
+ self._vertices.requires_grad_()
+ self._ref_len.fill_(edge_len_lims[1])
+
+ @property
+ def vertices(self):
+ return self._vertices
+
+ @property
+ def faces(self):
+ return self._faces
+
+ def _split_vertices_etc(self):
+ self._vertices = self._vertices_etc[:,:3]
+ self._m2 = self._vertices_etc[:,3]
+ self._nu = self._vertices_etc[:,4]
+ self._m1 = self._vertices_etc[:,5:8]
+ self._ref_len = self._vertices_etc[:,8]
+
+ with_gammas = any(g!=0 for g in self._gammas)
+ self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3]
+
+ def zero_grad(self):
+ self._vertices.grad = None
+
+ @torch.no_grad()
+ def step(self):
+
+ eps = 1e-8
+
+ self._step += 1
+
+ # spatial smoothing
+ edges,_ = calc_edges(self._faces) #E,2
+ E = edges.shape[0]
+ edge_smooth = self._smooth[edges] #E,2,S
+ neighbor_smooth = torch.zeros_like(self._smooth) #V,S
+ torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth)
+
+ #apply optional smoothing of m1,m2,nu
+ if self._gammas[0]:
+ self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0])
+ if self._gammas[1]:
+ self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1])
+ if self._gammas[2]:
+ self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2])
+
+ #add laplace smoothing to gradients
+ laplace = self._vertices - neighbor_smooth[:,:3]
+ grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight)
+
+ #gradient clipping
+ if self._step>1:
+ grad_lim = self._m1.abs().mul_(self._grad_lim)
+ grad.clamp_(min=-grad_lim,max=grad_lim)
+
+ # moment updates
+ lerp_unbiased(self._m1, grad, self._betas[0], self._step)
+ lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step)
+
+ velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3
+ speed = velocity.norm(dim=-1) #V
+
+ if self._betas[2]:
+ lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V
+ else:
+ self._nu.copy_(speed) #V
+
+ # update vertices
+ ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp)
+ self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr)
+
+ # update target edge length
+ if self._step % self._remesh_interval == 0:
+ if self._local_edgelen:
+ len_change = (1 + (self._nu - self._nu_ref) * self._gain)
+ else:
+ len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain)
+ self._ref_len *= len_change
+ self._ref_len.clamp_(*self._edge_len_lims)
+
+ def remesh(self, flip:bool=True, poisson=False)->Tuple[torch.Tensor,torch.Tensor]:
+ min_edge_len = self._ref_len * (1 - self._edge_len_tol)
+ max_edge_len = self._ref_len * (1 + self._edge_len_tol)
+
+ self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip, max_vertices=1e6)
+
+ self._split_vertices_etc()
+ self._vertices.requires_grad_()
+
+ return self._vertices, self._faces
diff --git a/mesh_reconstruction/recon.py b/mesh_reconstruction/recon.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cd1e17f37092808c14ed4c7a4ac75de84fde61e
--- /dev/null
+++ b/mesh_reconstruction/recon.py
@@ -0,0 +1,57 @@
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+import torch
+from typing import List
+from mesh_reconstruction.remesh import calc_vertex_normals
+from mesh_reconstruction.opt import MeshOptimizer
+from mesh_reconstruction.func import make_star_cameras_orthographic
+from mesh_reconstruction.render import NormalsRenderer
+from scripts.utils import to_py3d_mesh, init_target
+
+def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1):
+ vertices, faces = vertices.to("cuda"), faces.to("cuda")
+ assert len(pils) == 4
+ mv,proj = make_star_cameras_orthographic(4, 1)
+ renderer = NormalsRenderer(mv,proj,list(pils[0].size))
+
+ target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
+ # 1. no rotate
+ target_images = target_images[[0, 3, 2, 1]]
+
+ # 2. init from coarse mesh
+ opt = MeshOptimizer(vertices,faces, local_edgelen=False, gain=gain, edge_len_lims=(end_edge_len, start_edge_len))
+
+ vertices = opt.vertices
+
+ mask = target_images[..., -1] < 0.5
+
+ for i in tqdm(range(steps)):
+ opt.zero_grad()
+ opt._lr *= decay
+ normals = calc_vertex_normals(vertices,faces)
+ images = renderer.render(vertices,normals,faces)
+
+ loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean()
+
+ t_mask = images[..., -1] > 0.5
+ loss_target_l2 = (images[t_mask] - target_images[t_mask]).abs().pow(2).mean()
+ loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()
+
+ loss = loss_target_l2 + loss_alpha_target_mask_l2 + loss_expand * loss_expansion_weight
+
+ # out of box
+ loss_oob = (vertices.abs() > 0.99).float().mean() * 10
+ loss = loss + loss_oob
+
+ loss.backward()
+ opt.step()
+
+ vertices,faces = opt.remesh(poisson=False)
+
+ vertices, faces = vertices.detach(), faces.detach()
+
+ if return_mesh:
+ return to_py3d_mesh(vertices, faces)
+ else:
+ return vertices, faces
diff --git a/mesh_reconstruction/refine.py b/mesh_reconstruction/refine.py
new file mode 100644
index 0000000000000000000000000000000000000000..d10e244f466ae92882325ffcaa494505535b05e5
--- /dev/null
+++ b/mesh_reconstruction/refine.py
@@ -0,0 +1,77 @@
+from tqdm import tqdm
+from PIL import Image
+import torch
+from typing import List
+from mesh_reconstruction.remesh import calc_vertex_normals
+from mesh_reconstruction.opt import MeshOptimizer
+from mesh_reconstruction.func import make_star_cameras_orthographic
+from mesh_reconstruction.render import NormalsRenderer
+from scripts.project_mesh import multiview_color_projection, get_cameras_list
+from scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target
+
+def run_mesh_refine(vertices, faces, pils: List[Image.Image], steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=10, update_warmup=10, return_mesh=True, process_inputs=True, process_outputs=True):
+ if process_inputs:
+ vertices = vertices * 2 / 1.35
+ vertices[..., [0, 2]] = - vertices[..., [0, 2]]
+
+    poisson_steps = []
+
+ assert len(pils) == 4
+ mv,proj = make_star_cameras_orthographic(4, 1)
+ renderer = NormalsRenderer(mv,proj,list(pils[0].size))
+
+ target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
+ # 1. no rotate
+ target_images = target_images[[0, 3, 2, 1]]
+
+ # 2. init from coarse mesh
+ opt = MeshOptimizer(vertices,faces, ramp=5, edge_len_lims=(end_edge_len, start_edge_len), local_edgelen=False, laplacian_weight=0.02)
+
+ vertices = opt.vertices
+ alpha_init = None
+
+ mask = target_images[..., -1] < 0.5
+
+ for i in tqdm(range(steps)):
+ opt.zero_grad()
+ opt._lr *= decay
+ normals = calc_vertex_normals(vertices,faces)
+ images = renderer.render(vertices,normals,faces)
+ if alpha_init is None:
+ alpha_init = images.detach()
+
+ if i < update_warmup or i % update_normal_interval == 0:
+ with torch.no_grad():
+ py3d_mesh = to_py3d_mesh(vertices, faces, normals)
+ cameras = get_cameras_list(azim_list = [0, 90, 180, 270], device=vertices.device, focal=1.)
+ _, _, target_normal = from_py3d_mesh(multiview_color_projection(py3d_mesh, pils, cameras_list=cameras, weights=[2.0, 0.8, 1.0, 0.8], confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy='original', reweight_with_cosangle='linear'))
+ target_normal = target_normal * 2 - 1
+ target_normal = torch.nn.functional.normalize(target_normal, dim=-1)
+ debug_images = renderer.render(vertices,target_normal,faces)
+
+ d_mask = images[..., -1] > 0.5
+ loss_debug_l2 = (images[..., :3][d_mask] - debug_images[..., :3][d_mask]).pow(2).mean()
+
+ loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()
+
+ loss = loss_debug_l2 + loss_alpha_target_mask_l2
+
+ # out of box
+ loss_oob = (vertices.abs() > 0.99).float().mean() * 10
+ loss = loss + loss_oob
+
+ loss.backward()
+ opt.step()
+
+        vertices,faces = opt.remesh(poisson=(i in poisson_steps))
+
+ vertices, faces = vertices.detach(), faces.detach()
+
+ if process_outputs:
+ vertices = vertices / 2 * 1.35
+ vertices[..., [0, 2]] = - vertices[..., [0, 2]]
+
+ if return_mesh:
+ return to_py3d_mesh(vertices, faces)
+ else:
+ return vertices, faces
diff --git a/mesh_reconstruction/remesh.py b/mesh_reconstruction/remesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e2faa83cf5afc98b571254c6cee3893a2396a30
--- /dev/null
+++ b/mesh_reconstruction/remesh.py
@@ -0,0 +1,361 @@
+# modified from https://github.com/Profactor/continuous-remeshing
+import torch
+import torch.nn.functional as tfunc
+import torch_scatter
+from typing import Tuple
+
+def prepend_dummies(
+ vertices:torch.Tensor, #V,D
+ faces:torch.Tensor, #F,3 long
+ )->Tuple[torch.Tensor,torch.Tensor]:
+ """prepend dummy elements to vertices and faces to enable "masked" scatter operations"""
+ V,D = vertices.shape
+ vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0)
+ faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0)
+ return vertices,faces
+
+def remove_dummies(
+ vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced
+ faces:torch.Tensor, #F,3 long - first face all zeros
+ )->Tuple[torch.Tensor,torch.Tensor]:
+ """remove dummy elements added with prepend_dummies()"""
+ return vertices[1:],faces[1:]-1
+
+
+def calc_edges(
+ faces: torch.Tensor, # F,3 long - first face may be dummy with all zeros
+ with_edge_to_face: bool = False
+ ) -> Tuple[torch.Tensor, ...]:
+ """
+ returns Tuple of
+ - edges E,2 long, 0 for unused, lower vertex index first
+ - face_to_edge F,3 long
+ - (optional) edge_to_face shape=E,[left,right],[face,side]
+
+    o-<-----e1     e0, e1 ... the two edge endpoints
+ """
+
+ F = faces.shape[0]
+
+ # make full edges, lower vertex index first
+ face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F*3,3,2
+ full_edges = face_edges.reshape(F*3,2)
+ sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2
+
+ # make unique edges
+ edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3)
+ E = edges.shape[0]
+ face_to_edge = full_to_unique.reshape(F,3) #F,3
+
+ if not with_edge_to_face:
+ return edges, face_to_edge
+
+ is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3
+ edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2
+ scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2
+ edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2
+ edge_to_face[0] = 0
+ return edges, face_to_edge, edge_to_face
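+# Worked example (sketch, ignoring the dummy element added by prepend_dummies): for a single
+# triangle faces=[[0,1,2]], calc_edges returns edges=[[0,1],[0,2],[1,2]] and
+# face_to_edge=[[0,2,1]], i.e. each face entry indexes into the unique, sorted edge list.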
+
+def calc_edge_length(
+ vertices:torch.Tensor, #V,3 first may be dummy
+ edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused
+ )->torch.Tensor: #E
+
+ full_vertices = vertices[edges] #E,2,3
+ a,b = full_vertices.unbind(dim=1) #E,3
+ return torch.norm(a-b,p=2,dim=-1)
+
+def calc_face_normals(
+ vertices:torch.Tensor, #V,3 first vertex may be unreferenced
+ faces:torch.Tensor, #F,3 long, first face may be all zero
+ normalize:bool=False,
+ )->torch.Tensor: #F,3
+ """
+         n
+         |
+         c0     corners ordered counterclockwise when
+        / \     looking onto surface (in neg normal direction)
+      c1---c2
+ """
+ full_vertices = vertices[faces] #F,C=3,3
+ v0,v1,v2 = full_vertices.unbind(dim=1) #F,3
+ face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3
+ if normalize:
+ face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1)
+ return face_normals #F,3
+
+def calc_vertex_normals(
+ vertices:torch.Tensor, #V,3 first vertex may be unreferenced
+ faces:torch.Tensor, #F,3 long, first face may be all zero
+ face_normals:torch.Tensor=None, #F,3, not normalized
+ )->torch.Tensor: #F,3
+
+ F = faces.shape[0]
+
+ if face_normals is None:
+ face_normals = calc_face_normals(vertices,faces)
+
+ vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3
+ vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3))
+ vertex_normals = vertex_normals.sum(dim=1) #V,3
+ return tfunc.normalize(vertex_normals, eps=1e-6, dim=1)
+
+def calc_face_ref_normals(
+ faces:torch.Tensor, #F,3 long, 0 for unused
+ vertex_normals:torch.Tensor, #V,3 first unused
+ normalize:bool=False,
+ )->torch.Tensor: #F,3
+ """calculate reference normals for face flip detection"""
+ full_normals = vertex_normals[faces] #F,C=3,3
+ ref_normals = full_normals.sum(dim=1) #F,3
+ if normalize:
+ ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1)
+ return ref_normals
+
+def pack(
+ vertices:torch.Tensor, #V,3 first unused and nan
+ faces:torch.Tensor, #F,3 long, 0 for unused
+ )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused
+ """removes unused elements in vertices and faces"""
+ V = vertices.shape[0]
+
+ # remove unused faces
+ used_faces = faces[:,0]!=0
+ used_faces[0] = True
+ faces = faces[used_faces] #sync
+
+ # remove unused vertices
+ used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device)
+ used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add')
+ used_vertices = used_vertices.any(dim=1)
+ used_vertices[0] = True
+ vertices = vertices[used_vertices] #sync
+
+ # update used faces
+ ind = torch.zeros(V,dtype=torch.long,device=vertices.device)
+ V1 = used_vertices.sum()
+ ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync
+ faces = ind[faces]
+
+ return vertices,faces
+
+def split_edges(
+ vertices:torch.Tensor, #V,3 first unused
+ faces:torch.Tensor, #F,3 long, 0 for unused
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
+ face_to_edge:torch.Tensor, #F,3 long 0 for unused
+ splits, #E bool
+ pack_faces:bool=True,
+ )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
+
+ # c2 c2 c...corners = faces
+ # . . . . s...side_vert, 0 means no split
+ # . . .N2 . S...shrunk_face
+ # . . . . Ni...new_faces
+ # s2 s1 s2|c2...s1|c1
+ # . . . . .
+ # . . . S . .
+ # . . . . N1 .
+ # c0...(s0=0)....c1 s0|c0...........c1
+ #
+ # pseudo-code:
+ # S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2]
+ # split = side_vert!=0 example:[False,True,True]
+ # N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0]
+ # N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0]
+ # N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1]
+
+ V = vertices.shape[0]
+ F = faces.shape[0]
+ S = splits.sum().item() #sync
+
+ if S==0:
+ return vertices,faces
+
+ edge_vert = torch.zeros_like(splits, dtype=torch.long) #E
+ edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync
+ side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split
+ split_edges = edges[splits] #S sync
+
+ #vertices
+ split_vertices = vertices[split_edges].mean(dim=1) #S,3
+ vertices = torch.concat((vertices,split_vertices),dim=0)
+
+ #faces
+ side_split = side_vert!=0 #F,3
+ shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split
+ new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3
+ faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3
+ if pack_faces:
+ mask = faces[:,0]!=0
+ mask[0] = True
+ faces = faces[mask] #F',3 sync
+
+ return vertices,faces
+
+def collapse_edges(
+ vertices:torch.Tensor, #V,3 first unused
+ faces:torch.Tensor, #F,3 long 0 for unused
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
+ priorities:torch.Tensor, #E float
+ stable:bool=False, #only for unit testing
+ )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
+
+ V = vertices.shape[0]
+
+ # check spacing
+ _,order = priorities.sort(stable=stable) #E
+ rank = torch.zeros_like(order)
+ rank[order] = torch.arange(0,len(rank),device=rank.device)
+ vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
+ edge_rank = rank #E
+ for i in range(3):
+ torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank)
+ edge_rank,_ = vert_rank[edges].max(dim=-1) #E
+ candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2
+
+ # check connectivity
+ vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
+ vert_connections[candidates[:,0]] = 1 #start
+ edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start
+ vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start
+ vert_connections[candidates] = 0 #clear start and end
+ edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start
+ vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start
+ collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end
+
+ # mean vertices
+ vertices[collapses[:,0]] = vertices[collapses].mean(dim=1)
+
+ # update faces
+ dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V
+ dest[collapses[:,1]] = dest[collapses[:,0]]
+ faces = dest[faces] #F,3
+ c0,c1,c2 = faces.unbind(dim=-1)
+ collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2)
+ faces[collapsed] = 0
+
+ return vertices,faces
+
+def calc_face_collapses(
+ vertices:torch.Tensor, #V,3 first unused
+ faces:torch.Tensor, #F,3 long, 0 for unused
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
+ face_to_edge:torch.Tensor, #F,3 long 0 for unused
+ edge_length:torch.Tensor, #E
+ face_normals:torch.Tensor, #F,3
+ vertex_normals:torch.Tensor, #V,3 first unused
+ min_edge_length:torch.Tensor=None, #V
+ area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio
+ shortest_probability = 0.8
+ )->torch.Tensor: #E edges to collapse
+
+ E = edges.shape[0]
+ F = faces.shape[0]
+
+ # face flips
+ ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3
+ face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F
+
+ # small faces
+ if min_edge_length is not None:
+ min_face_length = min_edge_length[faces].mean(dim=-1) #F
+ min_area = min_face_length**2 * area_ratio #F
+ face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F
+ face_collapses[0] = False
+
+ # faces to edges
+ face_length = edge_length[face_to_edge] #F,3
+
+ if shortest_probability<1:
+ #select shortest edge with shortest_probability chance
+ randlim = round(2/(1-shortest_probability))
+ rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face
+ sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3
+ local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None])
+ else:
+ local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face
+
+ edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index
+ edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device)
+ edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long())
+
+ return edge_collapses.bool()
+
+def flip_edges(
+ vertices:torch.Tensor, #V,3 first unused
+ faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused
+ edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first
+ edge_to_face:torch.Tensor, #E,[left,right],[face,side]
+ with_border:bool=True, #handle border edges (D=4 instead of D=6)
+ with_normal_check:bool=True, #check face normal flips
+ stable:bool=False, #only for unit testing
+ ):
+ V = vertices.shape[0]
+ E = edges.shape[0]
+ device=vertices.device
+ vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long
+ vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add')
+ neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner
+ neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2
+ edge_is_inside = neighbors.all(dim=-1) #E
+
+ if with_border:
+ # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices
+ # need to use float for masks in order to use scatter(reduce='multiply')
+ vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float
+ src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float
+ vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply')
+ vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long
+ vertex_degree -= 2 * vertex_is_inside #V long
+
+ neighbor_degrees = vertex_degree[neighbors] #E,LR=2
+ edge_degrees = vertex_degree[edges] #E,2
+ #
+ # loss = Sum_over_affected_vertices((new_degree-6)**2)
+ # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2)
+ # + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2)
+ # = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree))
+ #
+ loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E
+ candidates = torch.logical_and(loss_change<0, edge_is_inside) #E
+ loss_change = loss_change[candidates] #E'
+ if loss_change.shape[0]==0:
+ return
+
+ edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4
+ _,order = loss_change.sort(descending=True, stable=stable) #E'
+ rank = torch.zeros_like(order)
+ rank[order] = torch.arange(0,len(rank),device=rank.device)
+ vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4
+ torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank)
+ vertex_rank,_ = vertex_rank.max(dim=-1) #V
+ neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E'
+ flip = rank==neighborhood_rank #E'
+
+ if with_normal_check:
+        # cl-<-----e1     e0,e1 ... edge vertices; cl,cr ... opposite corners of the left and right faces
+ v = vertices[edges_neighbors] #E",4,3
+ v = v - v[:,0:1] #make relative to e0
+ e1 = v[:,1]
+ cl = v[:,2]
+ cr = v[:,3]
+ n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors
+ flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face
+ flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face
+
+ flip_edges_neighbors = edges_neighbors[flip] #E",4
+ flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2
+ flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3
+ faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3))
diff --git a/mesh_reconstruction/render.py b/mesh_reconstruction/render.py
new file mode 100644
index 0000000000000000000000000000000000000000..b09b0227a41322704fc4f328dc65087721c66945
--- /dev/null
+++ b/mesh_reconstruction/render.py
@@ -0,0 +1,53 @@
+# modified from https://github.com/Profactor/continuous-remeshing
+import nvdiffrast.torch as dr
+import torch
+from typing import Tuple
+
+def _warmup(glctx, device=None):
+ device = 'cuda' if device is None else device
+ #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
+ def tensor(*args, **kwargs):
+ return torch.tensor(*args, device=device, **kwargs)
+ pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32)
+ tri = tensor([[0, 1, 2]], dtype=torch.int32)
+ dr.rasterize(glctx, pos, tri, resolution=[256, 256])
+
+glctx = dr.RasterizeGLContext(output_db=False, device="cuda")
+
+class NormalsRenderer:
+
+ _glctx:dr.RasterizeGLContext = None
+
+ def __init__(
+ self,
+ mv: torch.Tensor, #C,4,4
+ proj: torch.Tensor, #C,4,4
+ image_size: Tuple[int,int],
+ mvp = None,
+ device=None,
+ ):
+ if mvp is None:
+ self._mvp = proj @ mv #C,4,4
+ else:
+ self._mvp = mvp
+ self._image_size = image_size
+ self._glctx = glctx
+ _warmup(self._glctx, device)
+
+ def render(self,
+ vertices: torch.Tensor, #V,3 float
+ normals: torch.Tensor, #V,3 float in [-1, 1]
+ faces: torch.Tensor, #F,3 long
+ ) ->torch.Tensor: #C,H,W,4
+
+ V = vertices.shape[0]
+ faces = faces.type(torch.int32)
+ vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4
+ vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4
+ rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4
+ vert_col = (normals+1)/2 #V,3
+ col,_ = dr.interpolate(vert_col, rast_out, faces) #C,H,W,3
+ alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1
+ col = torch.concat((col,alpha),dim=-1) #C,H,W,4
+ col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
+ return col #C,H,W,4
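+
+# Usage sketch (matches how the mesh reconstruction code drives the renderer):
+#
+#   mv, proj = make_star_cameras_orthographic(4, 1)      # from mesh_reconstruction.func
+#   renderer = NormalsRenderer(mv, proj, [512, 512])
+#   images = renderer.render(vertices, normals, faces)   # C,H,W,4 normals + alpha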
diff --git a/requirements-detail.txt b/requirements-detail.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cd0bc9a8425cadb2ad81cd178778484f59fcefda
--- /dev/null
+++ b/requirements-detail.txt
@@ -0,0 +1,27 @@
+accelerate==0.29.2
+datasets==2.18.0
+diffusers==0.27.2
+fire==0.6.0
+gradio==4.32.0
+jaxtyping==0.2.29
+numba==0.59.1
+numpy==1.26.4
+nvdiffrast==0.3.1
+omegaconf==2.3.0
+onnxruntime_gpu==1.17.0
+opencv_python==4.9.0.80
+opencv_python_headless==4.9.0.80
+ort_nightly_gpu==1.17.0.dev20240118002
+peft==0.10.0
+Pillow==10.3.0
+pygltflib==1.16.2
+pymeshlab==2023.12.post1
+pytorch3d==0.7.5
+rembg==2.0.56
+torch==2.1.0+cu121
+torch_scatter==2.1.2
+tqdm==4.64.1
+transformers==4.39.3
+trimesh==4.3.0
+typeguard==2.13.3
+wandb==0.16.6
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0bd717ca68a574b9ee2bf7159198f16e984e93d4
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,28 @@
+accelerate
+datasets
+diffusers>=0.26.3
+fire
+gradio
+jaxtyping
+numba
+numpy
+git+https://github.com/NVlabs/nvdiffrast.git
+omegaconf>=2.3.0
+onnxruntime_gpu
+opencv_python
+opencv_python_headless
+ort_nightly_gpu
+peft
+Pillow
+pygltflib
+pymeshlab>=2023.12
+git+https://github.com/facebookresearch/pytorch3d.git@stable
+rembg
+torch>=2.0.1
+torch_scatter
+tqdm
+transformers
+trimesh
+typeguard
+wandb
+xformers
diff --git a/scripts/all_typing.py b/scripts/all_typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..87b19aaefff18dc04a8a7c185cd9086a27e91c62
--- /dev/null
+++ b/scripts/all_typing.py
@@ -0,0 +1,42 @@
+# code from https://github.com/threestudio-project
+
+"""
+This module contains type annotations for the project, using
+1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
+2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
+
+Two types of typing checking can be used:
+1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
+2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
+"""
+
+# Basic types
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Literal,
+ NamedTuple,
+ NewType,
+ Optional,
+ Sized,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+# Tensor dtype
+# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
+from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
+
+# Config type
+from omegaconf import DictConfig
+
+# PyTorch Tensor type
+from torch import Tensor
+
+# Runtime type checking decorator
+from typeguard import typechecked as typechecker
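+
+# Illustrative example (not from the original module): combining a jaxtyping shape
+# annotation with the typeguard decorator for runtime checking. The helper below is
+# hypothetical and exists only to show how the imports above are meant to be used.
+@typechecker
+def _example_scale_points(points: Float[Tensor, "B N 3"], factor: float) -> Float[Tensor, "B N 3"]:
+    # typeguard validates the declared dtype/shape annotations when the function is called
+    return points * factor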
diff --git a/scripts/load_onnx.py b/scripts/load_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..954cd444b988d372314cc1c7983bfc9dca4e998e
--- /dev/null
+++ b/scripts/load_onnx.py
@@ -0,0 +1,48 @@
+import onnxruntime
+import torch
+
+providers = [
+ ('TensorrtExecutionProvider', {
+ 'device_id': 0,
+ 'trt_max_workspace_size': 8 * 1024 * 1024 * 1024,
+ 'trt_fp16_enable': True,
+ 'trt_engine_cache_enable': True,
+ }),
+ ('CUDAExecutionProvider', {
+ 'device_id': 0,
+ 'arena_extend_strategy': 'kSameAsRequested',
+ 'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
+ 'cudnn_conv_algo_search': 'HEURISTIC',
+ })
+]
+
+def load_onnx(file_path: str):
+ assert file_path.endswith(".onnx")
+ sess_opt = onnxruntime.SessionOptions()
+ ort_session = onnxruntime.InferenceSession(file_path, sess_opt=sess_opt, providers=providers)
+ return ort_session
+
+
+def load_onnx_caller(file_path: str, single_output=False):
+ ort_session = load_onnx(file_path)
+ def caller(*args):
+ torch_input = isinstance(args[0], torch.Tensor)
+ if torch_input:
+ torch_input_dtype = args[0].dtype
+ torch_input_device = args[0].device
+ # check all are torch.Tensor and have same dtype and device
+ assert all([isinstance(arg, torch.Tensor) for arg in args]), "All inputs must be torch.Tensor if the first input is a torch.Tensor"
+ assert all([arg.dtype == torch_input_dtype for arg in args]), "All inputs must have the same dtype if the first input is a torch.Tensor"
+ assert all([arg.device == torch_input_device for arg in args]), "All inputs must be on the same device if the first input is a torch.Tensor"
+ args = [arg.cpu().float().numpy() for arg in args]
+
+ ort_inputs = {ort_session.get_inputs()[idx].name: args[idx] for idx in range(len(args))}
+ ort_outs = ort_session.run(None, ort_inputs)
+
+ if torch_input:
+ ort_outs = [torch.tensor(ort_out, dtype=torch_input_dtype, device=torch_input_device) for ort_out in ort_outs]
+
+ if single_output:
+ return ort_outs[0]
+ return ort_outs
+ return caller
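+
+# Usage sketch (illustrative): wrapping the RealESRGAN ONNX checkpoint shipped in ckpt/
+# as a callable that accepts and returns torch tensors. The input shape below is only
+# for demonstration; see scripts/upsampler.py for the real caller.
+def _example_run_realesrgan():
+    sr_model = load_onnx_caller("ckpt/realesrgan-x4.onnx", single_output=True)
+    image = torch.rand(1, 3, 256, 256, device="cuda")  # placeholder input batch
+    return sr_model(image)  # returned on the same device and dtype as the input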
diff --git a/scripts/mesh_init.py b/scripts/mesh_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a70af530f71a0c39bfe7f99ecb28c6103bf84ec
--- /dev/null
+++ b/scripts/mesh_init.py
@@ -0,0 +1,132 @@
+from PIL import Image
+import torch
+import numpy as np
+from pytorch3d.structures import Meshes
+from pytorch3d.renderer import TexturesVertex
+from scripts.utils import meshlab_mesh_to_py3dmesh, py3dmesh_to_meshlab_mesh
+import pymeshlab
+
+_MAX_THREAD = 8
+
+# rgb and depth to mesh
+def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True, device="cuda"):
+ pixel_center = 0.5 if use_pixel_centers else 0
+ i, j = np.meshgrid(
+ np.arange(W, dtype=np.float32) + pixel_center,
+ np.arange(H, dtype=np.float32) + pixel_center,
+ indexing='xy'
+ )
+ i, j = torch.from_numpy(i).to(device), torch.from_numpy(j).to(device)
+
+ origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2 * H / W, torch.zeros_like(i)], dim=-1) # W, H, 3
+ directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3
+
+ return origins, directions
+
+def depth_and_color_to_mesh(rgb_BCHW, pred_HWC, valid_HWC=None, is_back=False):
+ if valid_HWC is None:
+ valid_HWC = torch.ones_like(pred_HWC).bool()
+ H, W = rgb_BCHW.shape[-2:]
+ rgb_BCHW = rgb_BCHW.flip(-2)
+ pred_HWC = pred_HWC.flip(0)
+ valid_HWC = valid_HWC.flip(0)
+ rays_o, rays_d = get_ortho_ray_directions_origins(W, H, device=rgb_BCHW.device)
+ verts = rays_o + rays_d * pred_HWC # [H, W, 3]
+ verts = verts.reshape(-1, 3) # [V, 3]
+ indexes = torch.arange(H * W).reshape(H, W).to(rgb_BCHW.device)
+ faces1 = torch.stack([indexes[:-1, :-1], indexes[:-1, 1:], indexes[1:, :-1]], dim=-1)
+ # faces1_valid = valid_HWC[:-1, :-1] | valid_HWC[:-1, 1:] | valid_HWC[1:, :-1]
+ faces1_valid = valid_HWC[:-1, :-1] & valid_HWC[:-1, 1:] & valid_HWC[1:, :-1]
+ faces2 = torch.stack([indexes[1:, 1:], indexes[1:, :-1], indexes[:-1, 1:]], dim=-1)
+ # faces2_valid = valid_HWC[1:, 1:] | valid_HWC[1:, :-1] | valid_HWC[:-1, 1:]
+ faces2_valid = valid_HWC[1:, 1:] & valid_HWC[1:, :-1] & valid_HWC[:-1, 1:]
+ faces = torch.cat([faces1[faces1_valid.expand_as(faces1)].reshape(-1, 3), faces2[faces2_valid.expand_as(faces2)].reshape(-1, 3)], dim=0) # (F, 3)
+ colors = (rgb_BCHW[0].permute((1,2,0)) / 2 + 0.5).reshape(-1, 3) # (V, 3)
+ if is_back:
+ verts = verts * torch.tensor([-1, 1, -1], dtype=verts.dtype, device=verts.device)
+
+ used_verts = faces.unique()
+ old_to_new_mapping = torch.zeros_like(verts[..., 0]).long()
+ old_to_new_mapping[used_verts] = torch.arange(used_verts.shape[0], device=verts.device)
+ new_faces = old_to_new_mapping[faces]
+ mesh = Meshes(verts=[verts[used_verts]], faces=[new_faces], textures=TexturesVertex(verts_features=[colors[used_verts]]))
+ return mesh
+
+def normalmap_to_depthmap(normal_np):
+ from scripts.normal_to_height_map import estimate_height_map
+ height = estimate_height_map(normal_np, raw_values=True, thread_count=_MAX_THREAD, target_iteration_count=96)
+ return height
+
+def transform_back_normal_to_front(normal_pil):
+ arr = np.array(normal_pil) # in [0, 255]
+ arr[..., 0] = 255-arr[..., 0]
+ arr[..., 2] = 255-arr[..., 2]
+ return Image.fromarray(arr.astype(np.uint8))
+
+def calc_w_over_h(normal_pil):
+ if isinstance(normal_pil, Image.Image):
+ arr = np.array(normal_pil)
+ else:
+ assert isinstance(normal_pil, np.ndarray)
+ arr = normal_pil
+ if arr.shape[-1] == 4:
+ alpha = arr[..., -1] / 255.
+ alpha[alpha >= 0.5] = 1
+ alpha[alpha < 0.5] = 0
+ else:
+ alpha = ~(arr.min(axis=-1) >= 250)
+ h_min, w_min = np.min(np.where(alpha), axis=1)
+ h_max, w_max = np.max(np.where(alpha), axis=1)
+ return (w_max - w_min) / (h_max - h_min)
+
+def build_mesh(normal_pil, rgb_pil, is_back=False, clamp_min=-1, scale=0.3, init_type="std", offset=0):
+ if is_back:
+ normal_pil = transform_back_normal_to_front(normal_pil)
+ normal_img = np.array(normal_pil)
+ rgb_img = np.array(rgb_pil)
+ if normal_img.shape[-1] == 4:
+ valid_HWC = normal_img[..., [3]] / 255
+ elif rgb_img.shape[-1] == 4:
+ valid_HWC = rgb_img[..., [3]] / 255
+ else:
+ raise ValueError("invalid input, either normal or rgb should have alpha channel")
+
+ real_height_pix = np.max(np.where(valid_HWC>0.5)[0]) - np.min(np.where(valid_HWC>0.5)[0])
+
+ heights = normalmap_to_depthmap(normal_img)
+ rgb_BCHW = torch.from_numpy(rgb_img[..., :3] / 255.).permute((2,0,1))[None]
+ valid_HWC[valid_HWC < 0.5] = 0
+ valid_HWC[valid_HWC >= 0.5] = 1
+ valid_HWC = torch.from_numpy(valid_HWC).bool()
+ if init_type == "std":
+ # accurate but not stable
+ pred_HWC = torch.from_numpy(heights / heights.max() * (real_height_pix / heights.shape[0]) * scale * 2).float()[..., None]
+ elif init_type == "thin":
+ heights = heights - heights.min()
+ heights = (heights / heights.max() * 0.2)
+ pred_HWC = torch.from_numpy(heights * scale).float()[..., None]
+ else:
+ # stable but not accurate
+ heights = heights - heights.min()
+ heights = (heights / heights.max() * (1-offset)) + offset # map to [offset, 1]
+ pred_HWC = torch.from_numpy(heights * scale).float()[..., None]
+
+ # set the border pixels to 0 height
+ import cv2
+ # edge filter
+ edge = cv2.Canny((valid_HWC[..., 0] * 255).numpy().astype(np.uint8), 0, 255)
+ edge = torch.from_numpy(edge).bool()[..., None]
+ pred_HWC[edge] = 0
+
+ valid_HWC[pred_HWC < clamp_min] = False
+ return depth_and_color_to_mesh(rgb_BCHW.cuda(), pred_HWC.cuda(), valid_HWC.cuda(), is_back)
+
+def fix_border_with_pymeshlab_fast(meshes: Meshes, poissson_depth=6, simplification=0):
+ ms = pymeshlab.MeshSet()
+ ms.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "cube_vcolor_mesh")
+ if simplification > 0:
+ ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
+ ms.apply_filter('generate_surface_reconstruction_screened_poisson', threads = 6, depth = poissson_depth, preclean = True)
+ if simplification > 0:
+ ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
+ return meshlab_mesh_to_py3dmesh(ms.current_mesh())
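+
+# Usage sketch (illustrative): lifting a front-facing RGBA normal map into a coarse mesh
+# and closing its border with screened Poisson reconstruction. "front_normal.png" is a
+# placeholder path; the real pipeline produces these images in scripts/multiview_inference.py.
+def _example_front_mesh():
+    front_normal = Image.open("front_normal.png")  # RGBA normal map (placeholder)
+    mesh = build_mesh(front_normal, front_normal, init_type="std")
+    return fix_border_with_pymeshlab_fast(mesh, poissson_depth=6, simplification=2000)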
diff --git a/scripts/multiview_inference.py b/scripts/multiview_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d620b73fda27607ae8f9ecd4c6299b59f54ca6c
--- /dev/null
+++ b/scripts/multiview_inference.py
@@ -0,0 +1,98 @@
+import os
+from PIL import Image
+from scripts.mesh_init import build_mesh, calc_w_over_h, fix_border_with_pymeshlab_fast
+from scripts.project_mesh import multiview_color_projection
+from scripts.refine_lr_to_sr import run_sr_fast
+from scripts.utils import simple_clean_mesh
+from app.utils import simple_remove, split_image
+from app.custom_models.normal_prediction import predict_normals
+from mesh_reconstruction.recon import reconstruct_stage1
+from mesh_reconstruction.refine import run_mesh_refine
+from scripts.project_mesh import get_cameras_list
+from scripts.utils import from_py3d_mesh, to_pyml_mesh
+from pytorch3d.structures import Meshes, join_meshes_as_scene
+import numpy as np
+
+def fast_geo(front_normal: Image.Image, back_normal: Image.Image, side_normal: Image.Image, clamp=0., init_type="std"):
+ if front_normal.mode == "RGB":
+ front_normal = simple_remove(front_normal, run_sr=False)
+ front_normal = front_normal.resize((192, 192))
+ if back_normal.mode == "RGB":
+ back_normal = simple_remove(back_normal, run_sr=False)
+ back_normal = back_normal.resize((192, 192))
+ if side_normal.mode == "RGB":
+ side_normal = simple_remove(side_normal, run_sr=False)
+ side_normal = side_normal.resize((192, 192))
+
+ # build mesh with front back projection # ~3s
+ side_w_over_h = calc_w_over_h(side_normal)
+ mesh_front = build_mesh(front_normal, front_normal, clamp_min=clamp, scale=side_w_over_h, init_type=init_type)
+ mesh_back = build_mesh(back_normal, back_normal, is_back=True, clamp_min=clamp, scale=side_w_over_h, init_type=init_type)
+ meshes = join_meshes_as_scene([mesh_front, mesh_back])
+ meshes = fix_border_with_pymeshlab_fast(meshes, poissson_depth=6, simplification=2000)
+ return meshes
+
+def refine_rgb(rgb_pils, front_pil):
+ from scripts.refine_lr_to_sr import refine_lr_with_sd
+ from scripts.utils import NEG_PROMPT
+ from app.utils import make_image_grid
+ from app.all_models import model_zoo
+ from app.utils import rgba_to_rgb
+ rgb_pil = make_image_grid(rgb_pils, rows=2)
+ prompt = "4views, multiview"
+ neg_prompt = NEG_PROMPT
+ control_image = rgb_pil.resize((1024, 1024))
+ refined_rgb = refine_lr_with_sd([rgb_pil], [rgba_to_rgb(front_pil)], [control_image], prompt_list=[prompt], neg_prompt_list=[neg_prompt], pipe=model_zoo.pipe_disney_controlnet_tile_ipadapter_i2i, strength=0.2, output_size=(1024, 1024))[0]
+ refined_rgbs = split_image(refined_rgb, rows=2)
+ return refined_rgbs
+
+def erode_alpha(img_list):
+ out_img_list = []
+ for idx, img in enumerate(img_list):
+ arr = np.array(img)
+ alpha = (arr[:, :, 3] > 127).astype(np.uint8)
+ # erode 1px
+ import cv2
+ alpha = cv2.erode(alpha, np.ones((3, 3), np.uint8), iterations=1)
+ alpha = (alpha * 255).astype(np.uint8)
+ img = Image.fromarray(np.concatenate([arr[:, :, :3], alpha[:, :, None]], axis=-1))
+ out_img_list.append(img)
+ return out_img_list
+
+def geo_reconstruct(rgb_pils, normal_pils, front_pil, do_refine=False, predict_normal=True, expansion_weight=0.1, init_type="std"):
+ if front_pil.size[0] <= 512:
+ front_pil = run_sr_fast([front_pil])[0]
+ if do_refine:
+ refined_rgbs = refine_rgb(rgb_pils, front_pil) # 6s
+ else:
+ refined_rgbs = [rgb.resize((512, 512), resample=Image.LANCZOS) for rgb in rgb_pils]
+ img_list = [front_pil] + run_sr_fast(refined_rgbs[1:])
+
+ if predict_normal:
+ rm_normals = predict_normals([img.resize((512, 512), resample=Image.LANCZOS) for img in img_list], guidance_scale=1.5)
+ else:
+ rm_normals = simple_remove([img.resize((512, 512), resample=Image.LANCZOS) for img in normal_pils])
+ # transfer the alpha channel of rm_normals to img_list
+ for idx, img in enumerate(rm_normals):
+ if idx == 0 and img_list[0].mode == "RGBA":
+ temp = img_list[0].resize((2048, 2048))
+ rm_normals[0] = Image.fromarray(np.concatenate([np.array(rm_normals[0])[:, :, :3], np.array(temp)[:, :, 3:4]], axis=-1))
+ continue
+ img_list[idx] = Image.fromarray(np.concatenate([np.array(img_list[idx]), np.array(img)[:, :, 3:4]], axis=-1))
+ assert img_list[0].mode == "RGBA"
+ assert np.mean(np.array(img_list[0])[..., 3]) < 250
+
+ img_list = [img_list[0]] + erode_alpha(img_list[1:])
+ normal_stg1 = [img.resize((512, 512)) for img in rm_normals]
+ if init_type in ["std", "thin"]:
+ meshes = fast_geo(normal_stg1[0], normal_stg1[2], normal_stg1[1], init_type=init_type)
+ _ = multiview_color_projection(meshes, rgb_pils, resolution=512, device="cuda", complete_unseen=False, confidence_threshold=0.1) # just check for validation, may throw error
+ vertices, faces, _ = from_py3d_mesh(meshes)
+ vertices, faces = reconstruct_stage1(normal_stg1, steps=200, vertices=vertices, faces=faces, start_edge_len=0.1, end_edge_len=0.02, gain=0.05, return_mesh=False, loss_expansion_weight=expansion_weight)
+ elif init_type in ["ball"]:
+ vertices, faces = reconstruct_stage1(normal_stg1, steps=200, end_edge_len=0.01, return_mesh=False, loss_expansion_weight=expansion_weight)
+ vertices, faces = run_mesh_refine(vertices, faces, rm_normals, steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=20, update_warmup=5, return_mesh=False, process_inputs=False, process_outputs=False)
+ meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=1, apply_sub_divide=True, sub_divide_threshold=0.25).to("cuda")
+ new_meshes = multiview_color_projection(meshes, img_list, resolution=1024, device="cuda", complete_unseen=True, confidence_threshold=0.2, cameras_list = get_cameras_list([0, 90, 180, 270], "cuda", focal=1))
+ return new_meshes
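+
+# Usage sketch (illustrative): running the reconstruction stage on four generated views.
+# The file names are placeholders; in the app these RGBA PILs (background already removed)
+# come from the multi-view diffusion model, with the front view kept at full resolution.
+if __name__ == "__main__":
+    rgb_views = [Image.open(f"view_{i}.png") for i in range(4)]  # placeholder inputs
+    mesh = geo_reconstruct(rgb_views, None, rgb_views[0], do_refine=False, predict_normal=True, init_type="std")
+    vertices, faces, colors = from_py3d_mesh(mesh)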
diff --git a/scripts/normal_to_height_map.py b/scripts/normal_to_height_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..9733e9bc771108fc34ebfc670c06467d83347d8a
--- /dev/null
+++ b/scripts/normal_to_height_map.py
@@ -0,0 +1,203 @@
+# code modified from https://github.com/YertleTurtleGit/depth-from-normals
+import numpy as np
+import cv2 as cv
+from multiprocessing.pool import ThreadPool as Pool
+from multiprocessing import cpu_count
+from typing import Tuple, List, Union
+import numba
+
+
+def calculate_gradients(
+ normals: np.ndarray, mask: np.ndarray
+) -> Tuple[np.ndarray, np.ndarray]:
+ horizontal_angle_map = np.arccos(np.clip(normals[:, :, 0], -1, 1))
+ left_gradients = np.zeros(normals.shape[:2])
+ left_gradients[mask != 0] = (1 - np.sin(horizontal_angle_map[mask != 0])) * np.sign(
+ horizontal_angle_map[mask != 0] - np.pi / 2
+ )
+
+ vertical_angle_map = np.arccos(np.clip(normals[:, :, 1], -1, 1))
+ top_gradients = np.zeros(normals.shape[:2])
+ top_gradients[mask != 0] = -(1 - np.sin(vertical_angle_map[mask != 0])) * np.sign(
+ vertical_angle_map[mask != 0] - np.pi / 2
+ )
+
+ return left_gradients, top_gradients
+
+
+@numba.jit(nopython=True)
+def integrate_gradient_field(
+ gradient_field: np.ndarray, axis: int, mask: np.ndarray
+) -> np.ndarray:
+ heights = np.zeros(gradient_field.shape)
+
+ for d1 in numba.prange(heights.shape[1 - axis]):
+ sum_value = 0
+ for d2 in range(heights.shape[axis]):
+ coordinates = (d1, d2) if axis == 1 else (d2, d1)
+
+ if mask[coordinates] != 0:
+ sum_value = sum_value + gradient_field[coordinates]
+ heights[coordinates] = sum_value
+ else:
+ sum_value = 0
+
+ return heights
+
+
+def calculate_heights(
+ left_gradients: np.ndarray, top_gradients, mask: np.ndarray
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ left_heights = integrate_gradient_field(left_gradients, 1, mask)
+ right_heights = np.fliplr(
+ integrate_gradient_field(np.fliplr(-left_gradients), 1, np.fliplr(mask))
+ )
+ top_heights = integrate_gradient_field(top_gradients, 0, mask)
+ bottom_heights = np.flipud(
+ integrate_gradient_field(np.flipud(-top_gradients), 0, np.flipud(mask))
+ )
+ return left_heights, right_heights, top_heights, bottom_heights
+
+
+def combine_heights(*heights: np.ndarray) -> np.ndarray:
+ return np.mean(np.stack(heights, axis=0), axis=0)
+
+
+def rotate(matrix: np.ndarray, angle: float) -> np.ndarray:
+ h, w = matrix.shape[:2]
+ center = (w / 2, h / 2)
+
+ rotation_matrix = cv.getRotationMatrix2D(center, angle, 1.0)
+ corners = cv.transform(
+ np.array([[[0, 0], [w, 0], [w, h], [0, h]]]), rotation_matrix
+ )[0]
+
+ _, _, w, h = cv.boundingRect(corners)
+
+ rotation_matrix[0, 2] += w / 2 - center[0]
+ rotation_matrix[1, 2] += h / 2 - center[1]
+ result = cv.warpAffine(matrix, rotation_matrix, (w, h), flags=cv.INTER_LINEAR)
+
+ return result
+
+
+def rotate_vector_field_normals(normals: np.ndarray, angle: float) -> np.ndarray:
+ angle = np.radians(angle)
+ cos_angle = np.cos(angle)
+ sin_angle = np.sin(angle)
+
+ rotated_normals = np.empty_like(normals)
+ rotated_normals[:, :, 0] = (
+ normals[:, :, 0] * cos_angle - normals[:, :, 1] * sin_angle
+ )
+ rotated_normals[:, :, 1] = (
+ normals[:, :, 0] * sin_angle + normals[:, :, 1] * cos_angle
+ )
+
+ return rotated_normals
+
+
+def centered_crop(image: np.ndarray, target_resolution: Tuple[int, int]) -> np.ndarray:
+ return image[
+ (image.shape[0] - target_resolution[0])
+ // 2 : (image.shape[0] - target_resolution[0])
+ // 2
+ + target_resolution[0],
+ (image.shape[1] - target_resolution[1])
+ // 2 : (image.shape[1] - target_resolution[1])
+ // 2
+ + target_resolution[1],
+ ]
+
+
+def integrate_vector_field(
+ vector_field: np.ndarray,
+ mask: np.ndarray,
+ target_iteration_count: int,
+ thread_count: int,
+) -> np.ndarray:
+ shape = vector_field.shape[:2]
+ angles = np.linspace(0, 90, target_iteration_count, endpoint=False)
+
+ def integrate_vector_field_angles(angles: List[float]) -> np.ndarray:
+ all_combined_heights = np.zeros(shape)
+
+ for angle in angles:
+ rotated_vector_field = rotate_vector_field_normals(
+ rotate(vector_field, angle), angle
+ )
+ rotated_mask = rotate(mask, angle)
+
+ left_gradients, top_gradients = calculate_gradients(
+ rotated_vector_field, rotated_mask
+ )
+ (
+ left_heights,
+ right_heights,
+ top_heights,
+ bottom_heights,
+ ) = calculate_heights(left_gradients, top_gradients, rotated_mask)
+
+ combined_heights = combine_heights(
+ left_heights, right_heights, top_heights, bottom_heights
+ )
+ combined_heights = centered_crop(rotate(combined_heights, -angle), shape)
+ all_combined_heights += combined_heights / len(angles)
+
+ return all_combined_heights
+
+ with Pool(processes=thread_count) as pool:
+ heights = pool.map(
+ integrate_vector_field_angles,
+ np.array(
+ np.array_split(angles, thread_count),
+ dtype=object,
+ ),
+ )
+ pool.close()
+ pool.join()
+
+ isotropic_height = np.zeros(shape)
+ for height in heights:
+ isotropic_height += height / thread_count
+
+ return isotropic_height
+
+
+def estimate_height_map(
+ normal_map: np.ndarray,
+ mask: Union[np.ndarray, None] = None,
+ height_divisor: float = 1,
+ target_iteration_count: int = 250,
+ thread_count: int = cpu_count(),
+ raw_values: bool = False,
+) -> np.ndarray:
+ if mask is None:
+ if normal_map.shape[-1] == 4:
+ mask = normal_map[:, :, 3] / 255
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ else:
+ mask = np.ones(normal_map.shape[:2], dtype=np.uint8)
+
+ normals = ((normal_map[:, :, :3].astype(np.float64) / 255) - 0.5) * 2
+ heights = integrate_vector_field(
+ normals, mask, target_iteration_count, thread_count
+ )
+
+ if raw_values:
+ return heights
+
+ heights /= height_divisor
+ heights[mask > 0] += 1 / 2
+ heights[mask == 0] = 1 / 2
+
+ heights *= 2**16 - 1
+
+ if np.min(heights) < 0 or np.max(heights) > 2**16 - 1:
+ raise OverflowError("Height values are clipping.")
+
+ heights = np.clip(heights, 0, 2**16 - 1)
+ heights = heights.astype(np.uint16)
+
+ return heights
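+
+# Usage sketch (illustrative): recovering a relative height map from an RGBA normal map.
+# With raw_values=True the integrated heights are returned unnormalized, which is how
+# scripts/mesh_init.py consumes them. The file name is a placeholder.
+if __name__ == "__main__":
+    bgra = cv.imread("front_normal.png", cv.IMREAD_UNCHANGED)  # placeholder path
+    normal_rgba = cv.cvtColor(bgra, cv.COLOR_BGRA2RGBA)        # channels are expected in RGB(A) order
+    heights = estimate_height_map(normal_rgba, raw_values=True, target_iteration_count=96)
+    print(heights.shape, float(heights.min()), float(heights.max()))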
diff --git a/scripts/project_mesh.py b/scripts/project_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..07635c3f104a5888e853233bf20abd86e08a0203
--- /dev/null
+++ b/scripts/project_mesh.py
@@ -0,0 +1,378 @@
+from typing import List
+import torch
+import numpy as np
+from PIL import Image
+from pytorch3d.renderer.cameras import look_at_view_transform, OrthographicCameras, CamerasBase
+from pytorch3d.renderer.mesh.rasterizer import Fragments
+from pytorch3d.structures import Meshes
+from pytorch3d.renderer import (
+ RasterizationSettings,
+ TexturesVertex,
+ FoVPerspectiveCameras,
+ FoVOrthographicCameras,
+)
+from pytorch3d.renderer import MeshRasterizer
+
+def get_camera(world_to_cam, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'):
+ # pytorch3d expects transforms as row-vectors, so flip rotation: https://github.com/facebookresearch/pytorch3d/issues/1183
+ R = world_to_cam[:3, :3].t()[None, ...]
+ T = world_to_cam[:3, 3][None, ...]
+ if cam_type == 'fov':
+ camera = FoVPerspectiveCameras(device=world_to_cam.device, R=R, T=T, fov=fov_in_degrees, degrees=True)
+ else:
+ focal_length = 1 / focal_length
+ camera = FoVOrthographicCameras(device=world_to_cam.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
+ return camera
+
+def render_pix2faces_py3d(meshes, cameras, H=512, W=512, blur_radius=0.0, faces_per_pixel=1):
+ """
+ Renders pix2face of visible faces.
+
+ :param mesh: Pytorch3d.structures.Meshes
+ :param cameras: pytorch3d.renderer.Cameras
+ :param H: target image height
+ :param W: target image width
+ :param blur_radius: Float distance in the range [0, 2] used to expand the face
+ bounding boxes for rasterization. Setting blur radius
+ results in blurred edges around the shape instead of a
+ hard boundary. Set to 0 for no blur.
+ :param faces_per_pixel: (int) Number of faces to keep track of per pixel.
+ We return the nearest faces_per_pixel faces along the z-axis.
+ """
+ # Define the settings for rasterization and shading
+ raster_settings = RasterizationSettings(
+ image_size=(H, W),
+ blur_radius=blur_radius,
+ faces_per_pixel=faces_per_pixel
+ )
+ rasterizer=MeshRasterizer(
+ cameras=cameras,
+ raster_settings=raster_settings
+ )
+ fragments: Fragments = rasterizer(meshes, cameras=cameras)
+ return {
+ "pix_to_face": fragments.pix_to_face[..., 0],
+ }
+
+import nvdiffrast.torch as dr
+
+def _warmup(glctx, device=None):
+ device = 'cuda' if device is None else device
+ # Windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
+ def tensor(*args, **kwargs):
+ return torch.tensor(*args, device=device, **kwargs)
+ pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32)
+ tri = tensor([[0, 1, 2]], dtype=torch.int32)
+ dr.rasterize(glctx, pos, tri, resolution=[256, 256])
+
+class Pix2FacesRenderer:
+ def __init__(self, device="cuda"):
+ self._glctx = dr.RasterizeGLContext(output_db=False, device=device)
+ self.device = device
+ _warmup(self._glctx, device)
+
+ def transform_vertices(self, meshes: Meshes, cameras: CamerasBase):
+ vertices = cameras.transform_points_ndc(meshes.verts_padded())
+
+ perspective_correct = cameras.is_perspective()
+ znear = cameras.get_znear()
+ if isinstance(znear, torch.Tensor):
+ znear = znear.min().item()
+ z_clip = None if not perspective_correct or znear is None else znear / 2
+
+ if z_clip:
+ vertices = vertices[vertices[..., 2] >= cameras.get_znear()][None] # clip
+ vertices = vertices * torch.tensor([-1, -1, 1]).to(vertices)
+ vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1).to(torch.float32)
+ return vertices
+
+ def render_pix2faces_nvdiff(self, meshes: Meshes, cameras: CamerasBase, H=512, W=512):
+ meshes = meshes.to(self.device)
+ cameras = cameras.to(self.device)
+ vertices = self.transform_vertices(meshes, cameras)
+ faces = meshes.faces_packed().to(torch.int32)
+ rast_out,_ = dr.rasterize(self._glctx, vertices, faces, resolution=(H, W), grad_db=False) #C,H,W,4
+ pix_to_face = rast_out[..., -1].to(torch.int32) - 1
+ return pix_to_face
+
+pix2faces_renderer = Pix2FacesRenderer()
+
+def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024):
+ # pix_to_face = render_pix2faces_py3d(meshes, cameras, H=resolution, W=resolution)['pix_to_face']
+ pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution)
+
+ unique_faces = torch.unique(pix_to_face.flatten())
+ unique_faces = unique_faces[unique_faces != -1]
+ return unique_faces
+
+def project_color(meshes: Meshes, cameras: CamerasBase, pil_image: Image.Image, use_alpha=True, eps=0.05, resolution=1024, device="cuda") -> dict:
+ """
+ Projects color from a given image onto a 3D mesh.
+
+ Args:
+ meshes (pytorch3d.structures.Meshes): The 3D mesh object.
+ cameras (pytorch3d.renderer.cameras.CamerasBase): The camera object.
+ pil_image (PIL.Image.Image): The input image.
+ use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True.
+ eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05.
+ resolution (int, optional): The resolution of the projection. Defaults to 1024.
+ device (str, optional): The device to use for computation. Defaults to "cuda".
+ debug (bool, optional): Whether to save debug images. Defaults to False.
+
+ Returns:
+ dict: A dictionary containing the following keys:
+ - "new_texture" (TexturesVertex): The updated texture with interpolated colors.
+ - "valid_verts" (Tensor of [M,3]): The indices of the vertices being projected.
+ - "valid_colors" (Tensor of [M,3]): The interpolated colors for the valid vertices.
+ """
+ meshes = meshes.to(device)
+ cameras = cameras.to(device)
+ image = torch.from_numpy(np.array(pil_image.convert("RGBA")) / 255.).permute((2, 0, 1)).float().to(device) # in CHW format of [0, 1.]
+ unique_faces = get_visible_faces(meshes, cameras, resolution=resolution)
+
+ # visible faces
+ faces_normals = meshes.faces_normals_packed()[unique_faces]
+ faces_normals = faces_normals / faces_normals.norm(dim=1, keepdim=True)
+ world_points = cameras.unproject_points(torch.tensor([[[0., 0., 0.1], [0., 0., 0.2]]]).to(device))[0]
+ view_direction = world_points[1] - world_points[0]
+ view_direction = view_direction / view_direction.norm(dim=0, keepdim=True)
+
+ # find invalid faces
+ cos_angles = (faces_normals * view_direction).sum(dim=1)
+ assert cos_angles.mean() < 0, f"The view direction is not correct. cos_angles.mean()={cos_angles.mean()}"
+ selected_faces = unique_faces[cos_angles < -eps]
+
+ # find verts
+ faces = meshes.faces_packed()[selected_faces] # [N, 3]
+ verts = torch.unique(faces.flatten()) # [N, 1]
+ verts_coordinates = meshes.verts_packed()[verts] # [N, 3]
+
+ # compute color
+ pt_tensor = cameras.transform_points(verts_coordinates)[..., :2] # NDC space points
+ valid = ~((pt_tensor.isnan()|(pt_tensor<-1)|(1<pt_tensor)).any(dim=-1)) # [N] mask over candidate vertices
+ valid_pt = pt_tensor[valid, :]
+ valid_idx = verts[valid]
+ # sample image colors at the projected points; PyTorch3D NDC has +X left / +Y up while
+ # grid_sample expects +X right / +Y down, so both axes are negated before sampling
+ grid = -valid_pt[None, :, None, :] # [1, M, 1, 2]
+ valid_color = torch.nn.functional.grid_sample(image[None], grid, align_corners=False, padding_mode="border", mode="bilinear")[0, :, :, 0].T.clamp(0, 1) # [M, 4]
+ valid_alpha, valid_color = valid_color[:, 3:], valid_color[:, :3]
+ if not use_alpha:
+     valid_alpha = torch.ones_like(valid_alpha)
+
+ # cosine between vertex normal and view direction, used downstream as projection confidence
+ valid_normals = meshes.verts_normals_packed()[valid_idx]
+ valid_normals = valid_normals / valid_normals.norm(dim=1, keepdim=True).clamp_min(1e-4)
+ cos_angles = (valid_normals * view_direction).sum(dim=1)
+
+ # blend the sampled colors into a copy of the current vertex texture
+ new_colors = meshes.textures.verts_features_packed().clone()
+ new_colors[valid_idx] = valid_color * valid_alpha + new_colors[valid_idx] * (1 - valid_alpha)
+ new_texture = TexturesVertex(verts_features=[new_colors])
+ return {
+     "new_texture": new_texture,
+     "valid_verts": valid_idx,
+     "valid_colors": valid_color,
+     "valid_alpha": valid_alpha,
+     "cos_angles": cos_angles,
+ }
+
+def complete_unseen_vertex_color(meshes: Meshes, valid_index: torch.Tensor) -> dict:
+ """
+ meshes: the mesh with vertex color to be completed.
+ valid_index: the index of the valid vertices, where valid means colors are fixed. [V, 1]
+ """
+ valid_index = valid_index.to(meshes.device)
+ colors = meshes.textures.verts_features_packed() # [V, 3]
+ V = colors.shape[0]
+
+ invalid_index = torch.ones_like(colors[:, 0]).bool() # [V]
+ invalid_index[valid_index] = False
+ invalid_index = torch.arange(V).to(meshes.device)[invalid_index]
+
+ L = meshes.laplacian_packed()
+ E = torch.sparse_coo_tensor(torch.tensor([list(range(V))] * 2), torch.ones((V,)), size=(V, V)).to(meshes.device)
+ L = L + E
+ # E = torch.eye(V, layout=torch.sparse_coo, device=meshes.device)
+ # L = L + E
+ colored_count = torch.ones_like(colors[:, 0]) # [V]
+ colored_count[invalid_index] = 0
+ L_invalid = torch.index_select(L, 0, invalid_index) # sparse [IV, V]
+
+ total_colored = colored_count.sum()
+ coloring_round = 0
+ stage = "uncolored"
+ from tqdm import tqdm
+ pbar = tqdm(miniters=100)
+ while stage == "uncolored" or coloring_round > 0:
+ new_color = torch.matmul(L_invalid, colors * colored_count[:, None]) # [IV, 3]
+ new_count = torch.matmul(L_invalid, colored_count)[:, None] # [IV, 1]
+ colors[invalid_index] = torch.where(new_count > 0, new_color / new_count, colors[invalid_index])
+ colored_count[invalid_index] = (new_count[:, 0] > 0).float()
+
+ new_total_colored = colored_count.sum()
+ if new_total_colored > total_colored:
+ total_colored = new_total_colored
+ coloring_round += 1
+ else:
+ stage = "colored"
+ coloring_round -= 1
+ pbar.update(1)
+ if coloring_round > 10000:
+ print("coloring_round > 10000, break")
+ break
+ assert not torch.isnan(colors).any()
+ meshes.textures = TexturesVertex(verts_features=[colors])
+ return meshes
+
+def multiview_color_projection(meshes: Meshes, image_list: List[Image.Image], cameras_list: List[CamerasBase]=None, camera_focal: float = 2 / 1.35, weights=None, eps=0.05, resolution=1024, device="cuda", reweight_with_cosangle="square", use_alpha=True, confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy="smooth") -> Meshes:
+ """
+ Projects color from a given image onto a 3D mesh.
+
+ Args:
+ meshes (pytorch3d.structures.Meshes): The 3D mesh object, only one mesh.
+ image_list (PIL.Image.Image): List of images.
+ cameras_list (list): List of cameras.
+ camera_focal (float, optional): The focal length of the camera, if cameras_list is not passed. Defaults to 2 / 1.35.
+ weights (list, optional): List of weights for each image, for ['front', 'front_right', 'right', 'back', 'left', 'front_left']. Defaults to None.
+ eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05.
+ resolution (int, optional): The resolution of the projection. Defaults to 1024.
+ device (str, optional): The device to use for computation. Defaults to "cuda".
+ reweight_with_cosangle (str, optional): How to reweight the projected color by the angle between the view direction and the vertex normal; one of None, 'linear', or 'square'. Defaults to "square".
+ use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True.
+ confidence_threshold (float, optional): The threshold for the confidence of the projected color, if final projection weight is less than this, we will use the original color. Defaults to 0.1.
+ complete_unseen (bool, optional): Whether to complete the unseen vertex color using laplacian. Defaults to False.
+
+ Returns:
+ Meshes: the colored mesh
+ """
+ # 1. preprocess inputs
+ if image_list is None:
+ raise ValueError("image_list is None")
+ if cameras_list is None:
+ if len(image_list) == 8:
+ cameras_list = get_8view_cameras(device, focal=camera_focal)
+ elif len(image_list) == 6:
+ cameras_list = get_6view_cameras(device, focal=camera_focal)
+ elif len(image_list) == 4:
+ cameras_list = get_4view_cameras(device, focal=camera_focal)
+ elif len(image_list) == 2:
+ cameras_list = get_2view_cameras(device, focal=camera_focal)
+ else:
+ raise ValueError("cameras_list is None, and can not be guessed from image_list")
+ if weights is None:
+ if len(image_list) == 8:
+ weights = [2.0, 0.05, 0.2, 0.02, 1.0, 0.02, 0.2, 0.05]
+ elif len(image_list) == 6:
+ weights = [2.0, 0.05, 0.2, 1.0, 0.2, 0.05]
+ elif len(image_list) == 4:
+ weights = [2.0, 0.2, 1.0, 0.2]
+ elif len(image_list) == 2:
+ weights = [1.0, 1.0]
+ else:
+ raise ValueError("weights is None, and can not be guessed from image_list")
+
+ # 2. run projection
+ meshes = meshes.clone().to(device)
+ if weights is None:
+ weights = [1. for _ in range(len(cameras_list))]
+ assert len(cameras_list) == len(image_list) == len(weights)
+ original_color = meshes.textures.verts_features_packed()
+ assert not torch.isnan(original_color).any()
+ texture_counts = torch.zeros_like(original_color[..., :1])
+ texture_values = torch.zeros_like(original_color)
+ max_texture_counts = torch.zeros_like(original_color[..., :1])
+ max_texture_values = torch.zeros_like(original_color)
+ for camera, image, weight in zip(cameras_list, image_list, weights):
+ ret = project_color(meshes, camera, image, eps=eps, resolution=resolution, device=device, use_alpha=use_alpha)
+ if reweight_with_cosangle == "linear":
+ weight = (ret['cos_angles'].abs() * weight)[:, None]
+ elif reweight_with_cosangle == "square":
+ weight = (ret['cos_angles'].abs() ** 2 * weight)[:, None]
+ if use_alpha:
+ weight = weight * ret['valid_alpha']
+ assert weight.min() > -0.0001
+ texture_counts[ret['valid_verts']] += weight
+ texture_values[ret['valid_verts']] += ret['valid_colors'] * weight
+ max_texture_values[ret['valid_verts']] = torch.where(weight > max_texture_counts[ret['valid_verts']], ret['valid_colors'], max_texture_values[ret['valid_verts']])
+ max_texture_counts[ret['valid_verts']] = torch.max(max_texture_counts[ret['valid_verts']], weight)
+
+ # Method2
+ texture_values = torch.where(texture_counts > confidence_threshold, texture_values / texture_counts, texture_values)
+ if below_confidence_strategy == "smooth":
+ texture_values = torch.where(texture_counts <= confidence_threshold, (original_color * (confidence_threshold - texture_counts) + texture_values) / confidence_threshold, texture_values)
+ elif below_confidence_strategy == "original":
+ texture_values = torch.where(texture_counts <= confidence_threshold, original_color, texture_values)
+ else:
+ raise ValueError(f"below_confidence_strategy={below_confidence_strategy} is not supported")
+ assert not torch.isnan(texture_values).any()
+ meshes.textures = TexturesVertex(verts_features=[texture_values])
+
+ if complete_unseen:
+ meshes = complete_unseen_vertex_color(meshes, torch.arange(texture_values.shape[0]).to(device)[texture_counts[:, 0] >= confidence_threshold])
+ ret_mesh = meshes.detach()
+ del meshes
+ return ret_mesh
+
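+# Usage sketch (illustrative): baking four orthographic views back onto a vertex-colored
+# mesh, mirroring the call in scripts/multiview_inference.py. `mesh` and `view_images` are
+# assumed to exist; vertices unseen by all views are filled by the Laplacian completion above.
+def _example_bake_views(mesh: Meshes, view_images: List[Image.Image]) -> Meshes:
+    cams = get_cameras_list([0, 90, 180, 270], "cuda", focal=1)
+    return multiview_color_projection(mesh, view_images, cameras_list=cams, resolution=1024, device="cuda", complete_unseen=True, confidence_threshold=0.2)
+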
+def get_cameras_list(azim_list, device, focal=2/1.35, dist=1.1):
+ ret = []
+ for azim in azim_list:
+ R, T = look_at_view_transform(dist, 0, azim)
+ w2c = torch.cat([R[0].T, T[0, :, None]], dim=1)
+ cameras: OrthographicCameras = get_camera(w2c, focal_length=focal, cam_type='orthogonal').to(device)
+ ret.append(cameras)
+ return ret
+
+def get_8view_cameras(device, focal=2/1.35):
+ return get_cameras_list(azim_list = [180, 225, 270, 315, 0, 45, 90, 135], device=device, focal=focal)
+
+def get_6view_cameras(device, focal=2/1.35):
+ return get_cameras_list(azim_list = [180, 225, 270, 0, 90, 135], device=device, focal=focal)
+
+def get_4view_cameras(device, focal=2/1.35):
+ return get_cameras_list(azim_list = [180, 270, 0, 90], device=device, focal=focal)
+
+def get_2view_cameras(device, focal=2/1.35):
+ return get_cameras_list(azim_list = [180, 0], device=device, focal=focal)
+
+def get_multiple_view_cameras(device, focal=2/1.35, offset=180, num_views=8, dist=1.1):
+ return get_cameras_list(azim_list = (np.linspace(0, 360, num_views+1)[:-1] + offset) % 360, device=device, focal=focal, dist=dist)
+
+def align_with_alpha_bbox(source_img, target_img, final_size=1024):
+ # align source_img with target_img using alpha channel
+ # source_img and target_img are PIL.Image.Image
+ source_img = source_img.convert("RGBA")
+ target_img = target_img.convert("RGBA").resize((final_size, final_size))
+ source_np = np.array(source_img)
+ target_np = np.array(target_img)
+ source_alpha = source_np[:, :, 3]
+ target_alpha = target_np[:, :, 3]
+ bbox_source_min, bbox_source_max = np.argwhere(source_alpha > 0).min(axis=0), np.argwhere(source_alpha > 0).max(axis=0)
+ bbox_target_min, bbox_target_max = np.argwhere(target_alpha > 0).min(axis=0), np.argwhere(target_alpha > 0).max(axis=0)
+ source_content = source_np[bbox_source_min[0]:bbox_source_max[0]+1, bbox_source_min[1]:bbox_source_max[1]+1, :]
+ # resize source_content to fit in the position of target_content
+ source_content = Image.fromarray(source_content).resize((bbox_target_max[1]-bbox_target_min[1]+1, bbox_target_max[0]-bbox_target_min[0]+1), resample=Image.BICUBIC)
+ target_np[bbox_target_min[0]:bbox_target_max[0]+1, bbox_target_min[1]:bbox_target_max[1]+1, :] = np.array(source_content)
+ return Image.fromarray(target_np)
+
+def load_image_list_from_mvdiffusion(mvdiffusion_path, front_from_pil_or_path=None):
+ import os
+ image_list = []
+ for dir in ['front', 'front_right', 'right', 'back', 'left', 'front_left']:
+ image_path = os.path.join(mvdiffusion_path, f"rgb_000_{dir}.png")
+ pil = Image.open(image_path)
+ if dir == 'front':
+ if front_from_pil_or_path is not None:
+ if isinstance(front_from_pil_or_path, str):
+ replace_pil = Image.open(front_from_pil_or_path)
+ else:
+ replace_pil = front_from_pil_or_path
+ # align replace_pil with pil using bounding box in alpha channel
+ pil = align_with_alpha_bbox(replace_pil, pil, final_size=1024)
+ image_list.append(pil)
+ return image_list
+
+def load_image_list_from_img_grid(img_grid_path, resolution = 1024):
+ img_list = []
+ grid = Image.open(img_grid_path)
+ w, h = grid.size
+ for row in range(0, h, resolution):
+ for col in range(0, w, resolution):
+ img_list.append(grid.crop((col, row, col + resolution, row + resolution)))
+ return img_list
\ No newline at end of file
diff --git a/scripts/refine_lr_to_sr.py b/scripts/refine_lr_to_sr.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9a50cab8cd6d722ce0bf47ceb42ba168070fd9d
--- /dev/null
+++ b/scripts/refine_lr_to_sr.py
@@ -0,0 +1,60 @@
+import torch
+import os
+
+import numpy as np
+from hashlib import md5
+def hash_img(img):
+ return md5(np.array(img).tobytes()).hexdigest()
+def hash_any(obj):
+ return md5(str(obj).encode()).hexdigest()
+
+def refine_lr_with_sd(pil_image_list, concept_img_list, control_image_list, prompt_list, pipe=None, strength=0.35, neg_prompt_list="", output_size=(512, 512), controlnet_conditioning_scale=1.):
+ with torch.no_grad():
+ images = pipe(
+ image=pil_image_list,
+ ip_adapter_image=concept_img_list,
+ prompt=prompt_list,
+ neg_prompt=neg_prompt_list,
+ num_inference_steps=50,
+ strength=strength,
+ height=output_size[0],
+ width=output_size[1],
+ control_image=control_image_list,
+ guidance_scale=5.0,
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
+ generator=torch.manual_seed(233),
+ ).images
+ return images
+
+SR_cache = None
+
+def run_sr_fast(source_pils, scale=4):
+ from PIL import Image
+ from scripts.upsampler import RealESRGANer
+ import numpy as np
+ global SR_cache
+ if SR_cache is not None:
+ upsampler = SR_cache
+ else:
+ upsampler = RealESRGANer(
+ scale=4,
+ onnx_path="ckpt/realesrgan-x4.onnx",
+ tile=0,
+ tile_pad=10,
+ pre_pad=0,
+ half=True,
+ gpu_id=0,
+ )
+ ret_pils = []
+ for idx, img_pils in enumerate(source_pils):
+ np_in = isinstance(img_pils, np.ndarray)
+ assert isinstance(img_pils, (Image.Image, np.ndarray))
+ img = np.array(img_pils)
+ output, _ = upsampler.enhance(img, outscale=scale)
+ if np_in:
+ ret_pils.append(output)
+ else:
+ ret_pils.append(Image.fromarray(output))
+ if SR_cache is None:
+ SR_cache = upsampler
+ return ret_pils
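+
+# Usage sketch (illustrative): the ONNX upsampler is created lazily on first use and cached
+# in SR_cache, so repeated calls reuse the same RealESRGANer. The file name is a placeholder;
+# both PIL images and numpy arrays are accepted and returned in kind.
+if __name__ == "__main__":
+    from PIL import Image
+    low_res = Image.open("input.png")              # placeholder image
+    high_res = run_sr_fast([low_res], scale=4)[0]  # 4x upscaled PIL image
+    high_res.save("input_x4.png")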
diff --git a/scripts/sd_model_zoo.py b/scripts/sd_model_zoo.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e5e271ca5b8f242998e151482b86984aac1c10d
--- /dev/null
+++ b/scripts/sd_model_zoo.py
@@ -0,0 +1,131 @@
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline
+from transformers import CLIPVisionModelWithProjection
+import torch
+from copy import deepcopy
+
+ENABLE_CPU_CACHE = False
+DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5"
+
+cached_models = {} # cache for models to avoid repeated loading, key is model name
+def cache_model(func):
+ def wrapper(*args, **kwargs):
+ if ENABLE_CPU_CACHE:
+ model_name = func.__name__ + str(args) + str(kwargs)
+ if model_name not in cached_models:
+ cached_models[model_name] = func(*args, **kwargs)
+ return cached_models[model_name]
+ else:
+ return func(*args, **kwargs)
+ return wrapper
+
+def copied_cache_model(func):
+ def wrapper(*args, **kwargs):
+ if ENABLE_CPU_CACHE:
+ model_name = func.__name__ + str(args) + str(kwargs)
+ if model_name not in cached_models:
+ cached_models[model_name] = func(*args, **kwargs)
+ return deepcopy(cached_models[model_name])
+ else:
+ return func(*args, **kwargs)
+ return wrapper
+
+def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs):
+ if ckpt_or_pretrained.endswith(".safetensors"):
+ pipe = model_cls.from_single_file(ckpt_or_pretrained, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs)
+ else:
+ pipe = model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs)
+ return pipe
+
+@copied_cache_model
+def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16):
+ model_kwargs = dict(
+ torch_dtype=torch_dtype,
+ requires_safety_checker=False,
+ safety_checker=None,
+ )
+ pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
+ base_model,
+ StableDiffusionPipeline,
+ **model_kwargs
+ )
+ pipe.to("cpu")
+ return pipe.components
+
+@cache_model
+def load_controlnet(controlnet_path, torch_dtype=torch.float16):
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)
+ return controlnet
+
+@cache_model
+def load_image_encoder():
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ "h94/IP-Adapter",
+ subfolder="models/image_encoder",
+ torch_dtype=torch.float16,
+ )
+ return image_encoder
+
+def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="auto", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs):
+ model_kwargs = dict(
+ torch_dtype=torch_dtype,
+ device_map=device,
+ requires_safety_checker=False,
+ safety_checker=None,
+ )
+ components = load_base_model_components(base_model=base_model, torch_dtype=torch_dtype)
+ model_kwargs.update(components)
+ model_kwargs.update(kwargs)
+
+ if controlnet is not None:
+ if isinstance(controlnet, list):
+ controlnet = [load_controlnet(controlnet_path, torch_dtype=torch_dtype) for controlnet_path in controlnet]
+ else:
+ controlnet = load_controlnet(controlnet, torch_dtype=torch_dtype)
+ model_kwargs.update(controlnet=controlnet)
+
+ if pipeline_class is None:
+ if controlnet is not None:
+ pipeline_class = StableDiffusionControlNetPipeline
+ else:
+ pipeline_class = StableDiffusionPipeline
+
+ pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
+ base_model,
+ pipeline_class,
+ **model_kwargs
+ )
+
+ if ip_adapter:
+ image_encoder = load_image_encoder()
+ pipe.image_encoder = image_encoder
+ if plus_model:
+ pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.safetensors")
+ else:
+ pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.safetensors")
+ pipe.set_ip_adapter_scale(1.0)
+ else:
+ pipe.unload_ip_adapter()
+
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+
+ if model_cpu_offload_seq is None:
+ if isinstance(pipe, StableDiffusionControlNetPipeline):
+ pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae"
+ elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline):
+ pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae"
+ else:
+ pipe.model_cpu_offload_seq = model_cpu_offload_seq
+
+ if enable_sequential_cpu_offload:
+ pipe.enable_sequential_cpu_offload()
+ else:
+ pipe = pipe.to("cuda")
+ # pipe.enable_model_cpu_offload()
+ if vae_slicing:
+ pipe.enable_vae_slicing()
+
+ import gc
+ gc.collect()
+ return pipe
+
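+# Usage sketch (illustrative): building a tile-ControlNet img2img pipeline with the
+# IP-Adapter image encoder attached, roughly how app.all_models wires it up. Pointing
+# controlnet at the ckpt/controlnet-tile folder from the README is an assumption here.
+if __name__ == "__main__":
+    pipe = load_common_sd15_pipe(
+        base_model=DEFAULT_BASE_MODEL,
+        controlnet="ckpt/controlnet-tile",
+        ip_adapter=True,
+        pipeline_class=StableDiffusionControlNetImg2ImgPipeline,
+    )
+    print(type(pipe).__name__)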
diff --git a/scripts/upsampler.py b/scripts/upsampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f4999ab864c9eb0282832fb1ad02b63e6014926
--- /dev/null
+++ b/scripts/upsampler.py
@@ -0,0 +1,229 @@
+import cv2
+import math
+import numpy as np
+import os
+import torch
+from torch.nn import functional as F
+from scripts.load_onnx import load_onnx_caller
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+class RealESRGANer():
+ """A helper class for upsampling images with RealESRGAN.
+
+ Args:
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+ onnx_path (str): Path to the exported RealESRGAN ONNX model.
+ tile (int): Since overly large inputs can exhaust GPU memory, this option first crops the input
+ image into tiles and processes each one separately; the results are then merged back into
+ one image. 0 disables tiling. Default: 0.
+ tile_pad (int): The pad size for each tile, used to remove border artifacts. Default: 10.
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+ half (bool): Whether to use half precision during inference. Default: False.
+ """
+
+ def __init__(self,
+ scale,
+ onnx_path,
+ tile=0,
+ tile_pad=10,
+ pre_pad=10,
+ half=False,
+ device=None,
+ gpu_id=None):
+ self.scale = scale
+ self.tile_size = tile
+ self.tile_pad = tile_pad
+ self.pre_pad = pre_pad
+ self.mod_scale = None
+ self.half = half
+
+ # initialize model
+ if gpu_id:
+ self.device = torch.device(
+ f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
+ else:
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
+ self.model = load_onnx_caller(onnx_path, single_output=True)
+ # warm up
+ sample_input = torch.randn(1, 3, 512, 512, device=self.device).float()
+ self.model(sample_input)
+
+ def pre_process(self, img):
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
+ """
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+ self.img = img.unsqueeze(0).to(self.device)
+ if self.half:
+ self.img = self.img.half()
+
+ # pre_pad
+ if self.pre_pad != 0:
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+ # mod pad for divisible borders
+ if self.scale == 2:
+ self.mod_scale = 2
+ elif self.scale == 1:
+ self.mod_scale = 4
+ if self.mod_scale is not None:
+ self.mod_pad_h, self.mod_pad_w = 0, 0
+ _, _, h, w = self.img.size()
+ if (h % self.mod_scale != 0):
+ self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+ if (w % self.mod_scale != 0):
+ self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+
+ def process(self):
+ # model inference
+ self.output = self.model(self.img)
+
+ def tile_process(self):
+ """It will first crop input images to tiles, and then process each tile.
+ Finally, all the processed tiles are merged into one images.
+
+ Modified from: https://github.com/ata4/esrgan-launcher
+ """
+ batch, channel, height, width = self.img.shape
+ output_height = height * self.scale
+ output_width = width * self.scale
+ output_shape = (batch, channel, output_height, output_width)
+
+ # start with black image
+ self.output = self.img.new_zeros(output_shape)
+ tiles_x = math.ceil(width / self.tile_size)
+ tiles_y = math.ceil(height / self.tile_size)
+
+ # loop over all tiles
+ for y in range(tiles_y):
+ for x in range(tiles_x):
+ # extract tile from input image
+ ofs_x = x * self.tile_size
+ ofs_y = y * self.tile_size
+ # input tile area on total image
+ input_start_x = ofs_x
+ input_end_x = min(ofs_x + self.tile_size, width)
+ input_start_y = ofs_y
+ input_end_y = min(ofs_y + self.tile_size, height)
+
+ # input tile area on total image with padding
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+ # input tile dimensions
+ input_tile_width = input_end_x - input_start_x
+ input_tile_height = input_end_y - input_start_y
+ tile_idx = y * tiles_x + x + 1
+ input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+ # upscale tile
+ try:
+ with torch.no_grad():
+ output_tile = self.model(input_tile)
+ except RuntimeError as error:
+ print('Error', error)
+ print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+ # output tile area on total image
+ output_start_x = input_start_x * self.scale
+ output_end_x = input_end_x * self.scale
+ output_start_y = input_start_y * self.scale
+ output_end_y = input_end_y * self.scale
+
+ # output tile area without padding
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+ # put tile into output image
+ self.output[:, :, output_start_y:output_end_y,
+ output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+ output_start_x_tile:output_end_x_tile]
+
+ def post_process(self):
+ # remove extra pad
+ if self.mod_scale is not None:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
+ # remove prepad
+ if self.pre_pad != 0:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
+ return self.output
+
+ @torch.no_grad()
+ def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
+ h_input, w_input = img.shape[0:2]
+ # img: numpy
+ img = img.astype(np.float32)
+ if np.max(img) > 256: # 16-bit image
+ max_range = 65535
+ print('\tInput is a 16-bit image')
+ else:
+ max_range = 255
+ img = img / max_range
+ if len(img.shape) == 2: # gray image
+ img_mode = 'L'
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ elif img.shape[2] == 4: # RGBA image with alpha channel
+ img_mode = 'RGBA'
+ alpha = img[:, :, 3]
+ img = img[:, :, 0:3]
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ if alpha_upsampler == 'realesrgan':
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+ else:
+ img_mode = 'RGB'
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ # ------------------- process image (without the alpha channel) ------------------- #
+ self.pre_process(img)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_img = self.post_process()
+ output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+ if img_mode == 'L':
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+
+ # ------------------- process the alpha channel if necessary ------------------- #
+ if img_mode == 'RGBA':
+ if alpha_upsampler == 'realesrgan':
+ self.pre_process(alpha)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_alpha = self.post_process()
+ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+ else: # use the cv2 resize for alpha channel
+ h, w = alpha.shape[0:2]
+ output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+ # merge the alpha channel
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+ output_img[:, :, 3] = output_alpha
+
+ # ------------------------------ return ------------------------------ #
+ if max_range == 65535: # 16-bit image
+ output = (output_img * 65535.0).round().astype(np.uint16)
+ else:
+ output = (output_img * 255.0).round().astype(np.uint8)
+
+ if outscale is not None and outscale != float(self.scale):
+ output = cv2.resize(
+ output, (
+ int(w_input * outscale),
+ int(h_input * outscale),
+ ), interpolation=cv2.INTER_LANCZOS4)
+
+ return output, img_mode
+
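+# Usage sketch (illustrative): upscaling a large image with tiling enabled so each 256 px
+# tile is processed separately, as described in the class docstring. The file name is a
+# placeholder; cv2.imread provides the BGR layout that enhance() expects.
+if __name__ == "__main__":
+    upsampler = RealESRGANer(scale=4, onnx_path="ckpt/realesrgan-x4.onnx", tile=256, tile_pad=10, pre_pad=0, half=True, gpu_id=0)
+    bgr = cv2.imread("large_input.png")  # placeholder path
+    upscaled, _ = upsampler.enhance(bgr, outscale=4)
+    cv2.imwrite("large_input_x4.png", upscaled)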
diff --git a/scripts/utils.py b/scripts/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5270afc2476848cb89b23bfa3d85dea7bfdff83f
--- /dev/null
+++ b/scripts/utils.py
@@ -0,0 +1,319 @@
+import torch
+import numpy as np
+from PIL import Image
+import pymeshlab
+import pymeshlab as ml
+from pymeshlab import PercentageValue
+from pytorch3d.renderer import TexturesVertex
+from pytorch3d.structures import Meshes
+from rembg import new_session, remove
+import torch.nn.functional as F
+from typing import List, Tuple
+import trimesh
+
+providers = [
+ ('CUDAExecutionProvider', {
+ 'device_id': 0,
+ 'arena_extend_strategy': 'kSameAsRequested',
+ 'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
+ 'cudnn_conv_algo_search': 'HEURISTIC',
+ })
+]
+
+session = new_session(providers=providers)
+
+NEG_PROMPT="sketch, sculpture, hand drawing, outline, single color, NSFW, lowres, bad anatomy,bad hands, text, error, missing fingers, yellow sleeves, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry,(worst quality:1.4),(low quality:1.4)"
+
+def load_mesh_with_trimesh(file_name, file_type=None):
+ import trimesh
+ mesh: trimesh.Trimesh = trimesh.load(file_name, file_type=file_type)
+ if isinstance(mesh, trimesh.Scene):
+ assert len(mesh.geometry) > 0
+ # save to obj first and load again to avoid offset issue
+ from io import BytesIO
+ with BytesIO() as f:
+ mesh.export(f, file_type="obj")
+ f.seek(0)
+ mesh = trimesh.load(f, file_type="obj")
+ if isinstance(mesh, trimesh.Scene):
+ # we lose texture information here
+ mesh = trimesh.util.concatenate(
+ tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
+ for g in mesh.geometry.values()))
+ assert isinstance(mesh, trimesh.Trimesh)
+
+ vertices = torch.from_numpy(mesh.vertices).T
+ faces = torch.from_numpy(mesh.faces).T
+ colors = None
+ if mesh.visual is not None:
+ if hasattr(mesh.visual, 'vertex_colors'):
+ colors = torch.from_numpy(mesh.visual.vertex_colors)[..., :3].T / 255.
+ if colors is None:
+ # print("Warning: no vertex color found in mesh! Filling it with gray.")
+ colors = torch.ones_like(vertices) * 0.5
+ return vertices, faces, colors
+
+def meshlab_mesh_to_py3dmesh(mesh: pymeshlab.Mesh) -> Meshes:
+ verts = torch.from_numpy(mesh.vertex_matrix()).float()
+ faces = torch.from_numpy(mesh.face_matrix()).long()
+ colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float()
+ textures = TexturesVertex(verts_features=[colors])
+ return Meshes(verts=[verts], faces=[faces], textures=textures)
+
+
+def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> pymeshlab.Mesh:
+ colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64)
+ m1 = pymeshlab.Mesh(
+ vertex_matrix=meshes.verts_packed().cpu().float().numpy().astype(np.float64),
+ face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32),
+ v_normals_matrix=meshes.verts_normals_packed().cpu().float().numpy().astype(np.float64),
+ v_color_matrix=colors_in)
+ return m1
+
+
+def to_pyml_mesh(vertices,faces):
+ m1 = pymeshlab.Mesh(
+ vertex_matrix=vertices.cpu().float().numpy().astype(np.float64),
+ face_matrix=faces.cpu().long().numpy().astype(np.int32),
+ )
+ return m1
+
+
+def to_py3d_mesh(vertices, faces, normals=None):
+ from pytorch3d.structures import Meshes
+ from pytorch3d.renderer.mesh.textures import TexturesVertex
+ mesh = Meshes(verts=[vertices], faces=[faces], textures=None)
+ if normals is None:
+ normals = mesh.verts_normals_packed()
+ # set normals as vertex colors
+ mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5])
+ return mesh
+
+
+def from_py3d_mesh(mesh):
+ return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed()
+
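+# Usage sketch (illustrative): round-tripping a vertex-colored mesh between pytorch3d and
+# pymeshlab with the helpers above, e.g. to run a MeshLab smoothing filter in between.
+# The filter name and parameters follow pymeshlab 2023.12 and are an assumption here.
+def _example_smooth_with_meshlab(meshes: Meshes, steps: int = 1) -> Meshes:
+    ms = ml.MeshSet()
+    ms.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "mesh")
+    ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=steps, cotangentweight=False)
+    return meshlab_mesh_to_py3dmesh(ms.current_mesh())
+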
+def rotate_normalmap_by_angle(normal_map: np.ndarray, angle: float):
+ """
+ rotate along y-axis
+ normal_map: np.array, shape=(H, W, 3) in [-1, 1]
+ angle: float, in degree
+ """
+ angle = angle / 180 * np.pi
+ R = np.array([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]])
+ return np.dot(normal_map.reshape(-1, 3), R.T).reshape(normal_map.shape)
+
+# from view coord to front view world coord
+def rotate_normals(normal_pils, return_types='np', rotate_direction=1) -> list: # values in [0, 255]
+ n_views = len(normal_pils)
+ ret = []
+ for idx, rgba_normal in enumerate(normal_pils):
+ # rotate normal
+ normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [0, 1]
+ alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1]
+ normal_np = normal_np * 2 - 1 # map to [-1, 1]
+ normal_np = rotate_normalmap_by_angle(normal_np, rotate_direction * idx * (360 / n_views))
+ normal_np = (normal_np + 1) / 2
+ normal_np = normal_np * alpha_np[..., None] # make bg black
+ rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[:, :, None] * 255] , axis=-1)
+ if return_types == 'np':
+ ret.append(rgba_normal_np)
+ elif return_types == 'pil':
+ ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
+ else:
+ raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
+ return ret
+
+
+def rotate_normalmap_by_angle_torch(normal_map, angle):
+ """
+ rotate along y-axis
+ normal_map: torch.Tensor, shape=(H, W, 3) in [-1, 1], device='cuda'
+ angle: float, in degree
+ """
+ angle = torch.tensor(angle / 180 * np.pi).to(normal_map)
+ R = torch.tensor([[torch.cos(angle), 0, torch.sin(angle)],
+ [0, 1, 0],
+ [-torch.sin(angle), 0, torch.cos(angle)]]).to(normal_map)
+ return torch.matmul(normal_map.view(-1, 3), R.T).view(normal_map.shape)
+
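+# GPU helper for rotate_normals_torch: rotate one RGBA normal map (uint8, HxWx4) by `angle`
+# degrees around +Y on the GPU and re-multiply by alpha so the background stays black.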
+def do_rotate(rgba_normal, angle):
+ rgba_normal = torch.from_numpy(rgba_normal).float().cuda() / 255
+ rotated_normal_tensor = rotate_normalmap_by_angle_torch(rgba_normal[..., :3] * 2 - 1, angle)
+ rotated_normal_tensor = (rotated_normal_tensor + 1) / 2
+ rotated_normal_tensor = rotated_normal_tensor * rgba_normal[:, :, [3]] # make bg black
+ rgba_normal_np = torch.cat([rotated_normal_tensor * 255, rgba_normal[:, :, [3]] * 255], dim=-1).cpu().numpy()
+ return rgba_normal_np
+
+def rotate_normals_torch(normal_pils, return_types='np', rotate_direction=1):
+ n_views = len(normal_pils)
+ ret = []
+ for idx, rgba_normal in enumerate(normal_pils):
+ # rotate normal
+ angle = rotate_direction * idx * (360 / n_views)
+ rgba_normal_np = do_rotate(np.array(rgba_normal), angle)
+ if return_types == 'np':
+ ret.append(rgba_normal_np)
+ elif return_types == 'pil':
+ ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
+ else:
+ raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
+ return ret
+
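+# Recomposite RGBA images over a new background color: the original background (sampled at
+# the top-left pixel) is first un-blended from semi-transparent pixels, then the recovered
+# colors are re-blended over `new_bkgd`.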
+def change_bkgd(img_pils, new_bkgd=(0., 0., 0.)):
+ ret = []
+ new_bkgd = np.array(new_bkgd).reshape(1, 1, 3)
+ for rgba_img in img_pils:
+ img_np = np.array(rgba_img)[:, :, :3] / 255
+ alpha_np = np.array(rgba_img)[:, :, 3] / 255
+ ori_bkgd = img_np[:1, :1]
+ # color = ori_color * alpha + bkgd * (1-alpha)
+ # ori_color = (color - bkgd * (1-alpha)) / alpha
+ alpha_np_clamp = np.clip(alpha_np, 1e-6, 1) # avoid divide by zero
+ ori_img_np = (img_np - ori_bkgd * (1 - alpha_np[..., None])) / alpha_np_clamp[..., None]
+ img_np = np.where(alpha_np[..., None] > 0.05, ori_img_np * alpha_np[..., None] + new_bkgd * (1 - alpha_np[..., None]), new_bkgd)
+ rgba_img_np = np.concatenate([img_np * 255, alpha_np[..., None] * 255], axis=-1)
+ ret.append(Image.fromarray(rgba_img_np.astype(np.uint8)))
+ return ret
+
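+# Same un-blend / re-blend trick as change_bkgd, but the new background for view `idx` is the
+# front-facing normal (0, 0, 1) rotated by that view's azimuth.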
+def change_bkgd_to_normal(normal_pils) -> List[Image.Image]:
+ n_views = len(normal_pils)
+ ret = []
+ for idx, rgba_normal in enumerate(normal_pils):
+ # calculate the background normal for this view
+ target_bkgd = rotate_normalmap_by_angle(np.array([[[0., 0., 1.]]]), idx * (360 / n_views))
+ normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [0, 1]
+ alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1]
+ normal_np = normal_np * 2 - 1 # map to [-1, 1]
+ old_bkgd = normal_np[:1,:1]
+ normal_np[alpha_np > 0.05] = (normal_np[alpha_np > 0.05] - old_bkgd * (1 - alpha_np[alpha_np > 0.05][..., None])) / alpha_np[alpha_np > 0.05][..., None]
+ normal_np = normal_np * alpha_np[..., None] + target_bkgd * (1 - alpha_np[..., None])
+ normal_np = (normal_np + 1) / 2
+ rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[..., None] * 255] , axis=-1)
+ ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
+ return ret
+
+
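+# Patch an exported .glb in place: attach a plain white, fully rough, non-metallic,
+# double-sided PBR material so viewers display the baked vertex colors without extra tinting.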
+def fix_vert_color_glb(mesh_path):
+ from pygltflib import GLTF2, Material, PbrMetallicRoughness
+ obj1 = GLTF2().load(mesh_path)
+ obj1.meshes[0].primitives[0].material = 0
+ obj1.materials.append(Material(
+ pbrMetallicRoughness = PbrMetallicRoughness(
+ baseColorFactor = [1.0, 1.0, 1.0, 1.0],
+ metallicFactor = 0.,
+ roughnessFactor = 1.0,
+ ),
+ emissiveFactor = [0.0, 0.0, 0.0],
+ doubleSided = True,
+ ))
+ obj1.save(mesh_path)
+
+
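+# Standard sRGB -> linear transfer function: c / 12.92 below the 0.04045 knee,
+# ((c + 0.055) / 1.055) ** 2.4 above it, clipped to [0, 1].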
+def srgb_to_linear(c_srgb):
+ c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
+ return c_linear.clip(0, 1.)
+
+
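+# Export a pytorch3d Meshes through trimesh. When apply_sRGB_to_LinearRGB is set, vertex
+# colors are converted from sRGB to linear before export; for .glb output the mesh is also
+# rotated 180 degrees around +Y and fix_vert_color_glb patches the saved file.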
+def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
+ # convert from pytorch3d meshes to trimesh mesh
+ vertices = meshes.verts_packed().cpu().float().numpy()
+ triangles = meshes.faces_packed().cpu().long().numpy()
+ np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
+ if save_glb_path.endswith(".glb"):
+ # rotate 180 along +Y
+ vertices[:, [0, 2]] = -vertices[:, [0, 2]]
+
+ if apply_sRGB_to_LinearRGB:
+ np_color = srgb_to_linear(np_color)
+ assert vertices.shape[0] == np_color.shape[0]
+ assert np_color.shape[1] == 3
+ assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"
+ mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
+ mesh.remove_unreferenced_vertices()
+ # save mesh
+ mesh.export(save_glb_path)
+ if save_glb_path.endswith(".glb"):
+ fix_vert_color_glb(save_glb_path)
+ print(f"saving to {save_glb_path}")
+
+
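+# Save a mesh as <prefix>[_<timestamp>].glb. Despite the name and the export_video flag, no
+# video is rendered here; the second return value is always the placeholder "novideo".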
+def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, dist=3.5, azim_offset=180, resolution=512, fov_in_degrees=1 / 1.15, cam_type="ortho", view_padding=60, export_video=True) -> Tuple[str, str]:
+ import time
+ if '.' in save_mesh_prefix:
+ save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1])
+ if with_timestamp:
+ save_mesh_prefix = save_mesh_prefix + f"_{int(time.time())}"
+ ret_mesh = save_mesh_prefix + ".glb"
+ # optimized version
+ save_py3dmesh_with_trimesh_fast(meshes, ret_mesh)
+ return ret_mesh, "novideo"
+
+
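+# Light pymeshlab cleanup: optional Laplacian smoothing and, if requested, non-manifold
+# repair followed by loop subdivision, then conversion back to a pytorch3d Meshes.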
+def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
+ ms = ml.MeshSet()
+ ms.add_mesh(pyml_mesh, "cube_mesh")
+
+ if apply_smooth:
+ ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False)
+ if apply_sub_divide: # 5s, slow
+ ms.apply_filter("meshing_repair_non_manifold_vertices")
+ ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces')
+ ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=PercentageValue(sub_divide_threshold))
+ return meshlab_mesh_to_py3dmesh(ms.current_mesh())
+
+
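+# Pad a PIL image to a centered square canvas filled with `background_color`.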
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
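+# Input preprocessing: downscale to at most 2048 px per side, remove the background with
+# rembg if the image has no alpha channel yet, crop to the alpha bounding box (alpha > 60),
+# and pad the crop to a square with a transparent border in `background_color`.
+# Usage sketch (path is illustrative): simple_preprocess(Image.open("input.png"))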
+def simple_preprocess(input_image, rembg_session=session, background_color=255):
+ RES = 2048
+ input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
+ if input_image.mode != 'RGBA':
+ image_rem = input_image.convert('RGBA')
+ input_image = remove(image_rem, alpha_matting=False, session=rembg_session)
+
+ arr = np.asarray(input_image)
+ alpha = arr[:, :, -1]
+ x_nonzero = np.nonzero((alpha > 60).sum(axis=1))
+ y_nonzero = np.nonzero((alpha > 60).sum(axis=0))
+ x_min = int(x_nonzero[0].min())
+ y_min = int(y_nonzero[0].min())
+ x_max = int(x_nonzero[0].max())
+ y_max = int(y_nonzero[0].max())
+ arr = arr[x_min: x_max, y_min: y_max]
+ input_image = Image.fromarray(arr)
+ input_image = expand2square(input_image, (background_color, background_color, background_color, 0))
+ return input_image
+
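+# Batched torch version of change_bkgd: stack the RGBA views on `device`, un-blend the
+# original background, re-blend over `new_bkgd`, and return an (N, H, W, 4) tensor in [0, 1].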
+def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"):
+ # Convert the background color to a PyTorch tensor
+ new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device)
+
+ # Convert all images to PyTorch tensors and process them
+ imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32)) for img in img_pils]).to(device) / 255
+ img_nps = imgs[..., :3]
+ alpha_nps = imgs[..., 3]
+ ori_bkgds = img_nps[:, :1, :1]
+
+ # Avoid divide by zero and calculate the original image
+ alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1)
+ ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))) / alpha_nps_clamp.unsqueeze(-1)
+ ori_img_nps = torch.clamp(ori_img_nps, 0, 1)
+ img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05, ori_img_nps * alpha_nps.unsqueeze(-1) + new_bkgd * (1 - alpha_nps.unsqueeze(-1)), new_bkgd)
+
+ rgba_img_np = torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1)
+ return rgba_img_np
\ No newline at end of file