Spaces:
Build error
Build error
Commit
•
f56be8c
0
Parent(s):
Duplicate from radames/PIFu-Clothed-Human-Digitization
Browse filesCo-authored-by: Radamés Ajna <radames@users.noreply.huggingface.co>
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +28 -0
- .gitignore +47 -0
- PIFu/.gitignore +1 -0
- PIFu/LICENSE.txt +48 -0
- PIFu/README.md +167 -0
- PIFu/apps/__init__.py +0 -0
- PIFu/apps/crop_img.py +75 -0
- PIFu/apps/eval.py +153 -0
- PIFu/apps/eval_spaces.py +138 -0
- PIFu/apps/prt_util.py +142 -0
- PIFu/apps/render_data.py +290 -0
- PIFu/apps/train_color.py +191 -0
- PIFu/apps/train_shape.py +183 -0
- PIFu/env_sh.npy +0 -0
- PIFu/environment.yml +19 -0
- PIFu/lib/__init__.py +0 -0
- PIFu/lib/colab_util.py +114 -0
- PIFu/lib/data/BaseDataset.py +46 -0
- PIFu/lib/data/EvalDataset.py +166 -0
- PIFu/lib/data/TrainDataset.py +390 -0
- PIFu/lib/data/__init__.py +2 -0
- PIFu/lib/ext_transform.py +78 -0
- PIFu/lib/geometry.py +55 -0
- PIFu/lib/mesh_util.py +91 -0
- PIFu/lib/model/BasePIFuNet.py +76 -0
- PIFu/lib/model/ConvFilters.py +112 -0
- PIFu/lib/model/ConvPIFuNet.py +99 -0
- PIFu/lib/model/DepthNormalizer.py +18 -0
- PIFu/lib/model/HGFilters.py +146 -0
- PIFu/lib/model/HGPIFuNet.py +142 -0
- PIFu/lib/model/ResBlkPIFuNet.py +201 -0
- PIFu/lib/model/SurfaceClassifier.py +71 -0
- PIFu/lib/model/VhullPIFuNet.py +70 -0
- PIFu/lib/model/__init__.py +5 -0
- PIFu/lib/net_util.py +396 -0
- PIFu/lib/options.py +161 -0
- PIFu/lib/renderer/__init__.py +0 -0
- PIFu/lib/renderer/camera.py +207 -0
- PIFu/lib/renderer/gl/__init__.py +0 -0
- PIFu/lib/renderer/gl/cam_render.py +48 -0
- PIFu/lib/renderer/gl/data/prt.fs +153 -0
- PIFu/lib/renderer/gl/data/prt.vs +167 -0
- PIFu/lib/renderer/gl/data/prt_uv.fs +141 -0
- PIFu/lib/renderer/gl/data/prt_uv.vs +168 -0
- PIFu/lib/renderer/gl/data/quad.fs +11 -0
- PIFu/lib/renderer/gl/data/quad.vs +11 -0
- PIFu/lib/renderer/gl/framework.py +90 -0
- PIFu/lib/renderer/gl/glcontext.py +142 -0
- PIFu/lib/renderer/gl/init_gl.py +24 -0
- PIFu/lib/renderer/gl/prt_render.py +350 -0
.gitattributes
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
19 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.glb filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
results/
|
2 |
+
# Python build
|
3 |
+
.eggs/
|
4 |
+
gradio.egg-info/*
|
5 |
+
!gradio.egg-info/requires.txt
|
6 |
+
!gradio.egg-info/PKG-INFO
|
7 |
+
dist/
|
8 |
+
*.pyc
|
9 |
+
__pycache__/
|
10 |
+
*.py[cod]
|
11 |
+
*$py.class
|
12 |
+
build/
|
13 |
+
|
14 |
+
# JS build
|
15 |
+
gradio/templates/frontend
|
16 |
+
# Secrets
|
17 |
+
.env
|
18 |
+
|
19 |
+
# Gradio run artifacts
|
20 |
+
*.db
|
21 |
+
*.sqlite3
|
22 |
+
gradio/launches.json
|
23 |
+
flagged/
|
24 |
+
# gradio_cached_examples/
|
25 |
+
|
26 |
+
# Tests
|
27 |
+
.coverage
|
28 |
+
coverage.xml
|
29 |
+
test.txt
|
30 |
+
|
31 |
+
# Demos
|
32 |
+
demo/tmp.zip
|
33 |
+
demo/files/*.avi
|
34 |
+
demo/files/*.mp4
|
35 |
+
|
36 |
+
# Etc
|
37 |
+
.idea/*
|
38 |
+
.DS_Store
|
39 |
+
*.bak
|
40 |
+
workspace.code-workspace
|
41 |
+
*.h5
|
42 |
+
.vscode/
|
43 |
+
|
44 |
+
# log files
|
45 |
+
.pnpm-debug.log
|
46 |
+
venv/
|
47 |
+
*.db-journal
|
PIFu/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
checkpoints/*
|
PIFu/LICENSE.txt
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2019 Shunsuke Saito, Zeng Huang, and Ryota Natsume
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
22 |
+
|
23 |
+
anyabagomo
|
24 |
+
|
25 |
+
-------------------- LICENSE FOR ResBlk Image Encoder -----------------------
|
26 |
+
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
|
27 |
+
All rights reserved.
|
28 |
+
|
29 |
+
Redistribution and use in source and binary forms, with or without
|
30 |
+
modification, are permitted provided that the following conditions are met:
|
31 |
+
|
32 |
+
* Redistributions of source code must retain the above copyright notice, this
|
33 |
+
list of conditions and the following disclaimer.
|
34 |
+
|
35 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
36 |
+
this list of conditions and the following disclaimer in the documentation
|
37 |
+
and/or other materials provided with the distribution.
|
38 |
+
|
39 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
40 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
41 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
42 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
43 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
44 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
45 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
46 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
47 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
48 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
PIFu/README.md
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PIFu: Pixel-Aligned Implicit Function for High-Resolution Clothed Human Digitization
|
2 |
+
|
3 |
+
[![report](https://img.shields.io/badge/arxiv-report-red)](https://arxiv.org/abs/1905.05172) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1GFSsqP2BWz4gtq0e-nki00ZHSirXwFyY)
|
4 |
+
|
5 |
+
News:
|
6 |
+
* \[2020/05/04\] Added EGL rendering option for training data generation. Now you can create your own training data with headless machines!
|
7 |
+
* \[2020/04/13\] Demo with Google Colab (incl. visualization) is available. Special thanks to [@nanopoteto](https://github.com/nanopoteto)!!!
|
8 |
+
* \[2020/02/26\] License is updated to MIT license! Enjoy!
|
9 |
+
|
10 |
+
This repository contains a pytorch implementation of "[PIFu: Pixel-Aligned Implicit Function for High-Resolution Clothed Human Digitization](https://arxiv.org/abs/1905.05172)".
|
11 |
+
|
12 |
+
[Project Page](https://shunsukesaito.github.io/PIFu/)
|
13 |
+
![Teaser Image](https://shunsukesaito.github.io/PIFu/resources/images/teaser.png)
|
14 |
+
|
15 |
+
If you find the code useful in your research, please consider citing the paper.
|
16 |
+
|
17 |
+
```
|
18 |
+
@InProceedings{saito2019pifu,
|
19 |
+
author = {Saito, Shunsuke and Huang, Zeng and Natsume, Ryota and Morishima, Shigeo and Kanazawa, Angjoo and Li, Hao},
|
20 |
+
title = {PIFu: Pixel-Aligned Implicit Function for High-Resolution Clothed Human Digitization},
|
21 |
+
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
|
22 |
+
month = {October},
|
23 |
+
year = {2019}
|
24 |
+
}
|
25 |
+
```
|
26 |
+
|
27 |
+
|
28 |
+
This codebase provides:
|
29 |
+
- test code
|
30 |
+
- training code
|
31 |
+
- data generation code
|
32 |
+
|
33 |
+
## Requirements
|
34 |
+
- Python 3
|
35 |
+
- [PyTorch](https://pytorch.org/) tested on 1.4.0
|
36 |
+
- json
|
37 |
+
- PIL
|
38 |
+
- skimage
|
39 |
+
- tqdm
|
40 |
+
- numpy
|
41 |
+
- cv2
|
42 |
+
|
43 |
+
for training and data generation
|
44 |
+
- [trimesh](https://trimsh.org/) with [pyembree](https://github.com/scopatz/pyembree)
|
45 |
+
- [pyexr](https://github.com/tvogels/pyexr)
|
46 |
+
- PyOpenGL
|
47 |
+
- freeglut (use `sudo apt-get install freeglut3-dev` for ubuntu users)
|
48 |
+
- (optional) egl related packages for rendering with headless machines. (use `apt install libgl1-mesa-dri libegl1-mesa libgbm1` for ubuntu users)
|
49 |
+
|
50 |
+
Warning: I found that outdated NVIDIA drivers may cause errors with EGL. If you want to try out the EGL version, please update your NVIDIA driver to the latest!!
|
51 |
+
|
52 |
+
## Windows demo installation instuction
|
53 |
+
|
54 |
+
- Install [miniconda](https://docs.conda.io/en/latest/miniconda.html)
|
55 |
+
- Add `conda` to PATH
|
56 |
+
- Install [git bash](https://git-scm.com/downloads)
|
57 |
+
- Launch `Git\bin\bash.exe`
|
58 |
+
- `eval "$(conda shell.bash hook)"` then `conda activate my_env` because of [this](https://github.com/conda/conda-build/issues/3371)
|
59 |
+
- Automatic `env create -f environment.yml` (look [this](https://github.com/conda/conda/issues/3417))
|
60 |
+
- OR manually setup [environment](https://towardsdatascience.com/a-guide-to-conda-environments-bc6180fc533)
|
61 |
+
- `conda create —name pifu python` where `pifu` is name of your environment
|
62 |
+
- `conda activate`
|
63 |
+
- `conda install pytorch torchvision cudatoolkit=10.1 -c pytorch`
|
64 |
+
- `conda install pillow`
|
65 |
+
- `conda install scikit-image`
|
66 |
+
- `conda install tqdm`
|
67 |
+
- `conda install -c menpo opencv`
|
68 |
+
- Download [wget.exe](https://eternallybored.org/misc/wget/)
|
69 |
+
- Place it into `Git\mingw64\bin`
|
70 |
+
- `sh ./scripts/download_trained_model.sh`
|
71 |
+
- Remove background from your image ([this](https://www.remove.bg/), for example)
|
72 |
+
- Create black-white mask .png
|
73 |
+
- Replace original from sample_images/
|
74 |
+
- Try it out - `sh ./scripts/test.sh`
|
75 |
+
- Download [Meshlab](http://www.meshlab.net/) because of [this](https://github.com/shunsukesaito/PIFu/issues/1)
|
76 |
+
- Open .obj file in Meshlab
|
77 |
+
|
78 |
+
|
79 |
+
## Demo
|
80 |
+
Warning: The released model is trained with mostly upright standing scans with weak perspectie projection and the pitch angle of 0 degree. Reconstruction quality may degrade for images highly deviated from trainining data.
|
81 |
+
1. run the following script to download the pretrained models from the following link and copy them under `./PIFu/checkpoints/`.
|
82 |
+
```
|
83 |
+
sh ./scripts/download_trained_model.sh
|
84 |
+
```
|
85 |
+
|
86 |
+
2. run the following script. the script creates a textured `.obj` file under `./PIFu/eval_results/`. You may need to use `./apps/crop_img.py` to roughly align an input image and the corresponding mask to the training data for better performance. For background removal, you can use any off-the-shelf tools such as [removebg](https://www.remove.bg/).
|
87 |
+
```
|
88 |
+
sh ./scripts/test.sh
|
89 |
+
```
|
90 |
+
|
91 |
+
## Demo on Google Colab
|
92 |
+
If you do not have a setup to run PIFu, we offer Google Colab version to give it a try, allowing you to run PIFu in the cloud, free of charge. Try our Colab demo using the following notebook:
|
93 |
+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1GFSsqP2BWz4gtq0e-nki00ZHSirXwFyY)
|
94 |
+
|
95 |
+
## Data Generation (Linux Only)
|
96 |
+
While we are unable to release the full training data due to the restriction of commertial scans, we provide rendering code using free models in [RenderPeople](https://renderpeople.com/free-3d-people/).
|
97 |
+
This tutorial uses `rp_dennis_posed_004` model. Please download the model from [this link](https://renderpeople.com/sample/free/rp_dennis_posed_004_OBJ.zip) and unzip the content under a folder named `rp_dennis_posed_004_OBJ`. The same process can be applied to other RenderPeople data.
|
98 |
+
|
99 |
+
Warning: the following code becomes extremely slow without [pyembree](https://github.com/scopatz/pyembree). Please make sure you install pyembree.
|
100 |
+
|
101 |
+
1. run the following script to compute spherical harmonics coefficients for [precomputed radiance transfer (PRT)](https://sites.fas.harvard.edu/~cs278/papers/prt.pdf). In a nutshell, PRT is used to account for accurate light transport including ambient occlusion without compromising online rendering time, which significantly improves the photorealism compared with [a common sperical harmonics rendering using surface normals](https://cseweb.ucsd.edu/~ravir/papers/envmap/envmap.pdf). This process has to be done once for each obj file.
|
102 |
+
```
|
103 |
+
python -m apps.prt_util -i {path_to_rp_dennis_posed_004_OBJ}
|
104 |
+
```
|
105 |
+
|
106 |
+
2. run the following script. Under the specified data path, the code creates folders named `GEO`, `RENDER`, `MASK`, `PARAM`, `UV_RENDER`, `UV_MASK`, `UV_NORMAL`, and `UV_POS`. Note that you may need to list validation subjects to exclude from training in `{path_to_training_data}/val.txt` (this tutorial has only one subject and leave it empty). If you wish to render images with headless servers equipped with NVIDIA GPU, add -e to enable EGL rendering.
|
107 |
+
```
|
108 |
+
python -m apps.render_data -i {path_to_rp_dennis_posed_004_OBJ} -o {path_to_training_data} [-e]
|
109 |
+
```
|
110 |
+
|
111 |
+
## Training (Linux Only)
|
112 |
+
|
113 |
+
Warning: the following code becomes extremely slow without [pyembree](https://github.com/scopatz/pyembree). Please make sure you install pyembree.
|
114 |
+
|
115 |
+
1. run the following script to train the shape module. The intermediate results and checkpoints are saved under `./results` and `./checkpoints` respectively. You can add `--batch_size` and `--num_sample_input` flags to adjust the batch size and the number of sampled points based on available GPU memory.
|
116 |
+
```
|
117 |
+
python -m apps.train_shape --dataroot {path_to_training_data} --random_flip --random_scale --random_trans
|
118 |
+
```
|
119 |
+
|
120 |
+
2. run the following script to train the color module.
|
121 |
+
```
|
122 |
+
python -m apps.train_color --dataroot {path_to_training_data} --num_sample_inout 0 --num_sample_color 5000 --sigma 0.1 --random_flip --random_scale --random_trans
|
123 |
+
```
|
124 |
+
|
125 |
+
## Related Research
|
126 |
+
**[Monocular Real-Time Volumetric Performance Capture (ECCV 2020)](https://project-splinter.github.io/)**
|
127 |
+
*Ruilong Li\*, Yuliang Xiu\*, Shunsuke Saito, Zeng Huang, Kyle Olszewski, Hao Li*
|
128 |
+
|
129 |
+
The first real-time PIFu by accelerating reconstruction and rendering!!
|
130 |
+
|
131 |
+
**[PIFuHD: Multi-Level Pixel-Aligned Implicit Function for High-Resolution 3D Human Digitization (CVPR 2020)](https://shunsukesaito.github.io/PIFuHD/)**
|
132 |
+
*Shunsuke Saito, Tomas Simon, Jason Saragih, Hanbyul Joo*
|
133 |
+
|
134 |
+
We further improve the quality of reconstruction by leveraging multi-level approach!
|
135 |
+
|
136 |
+
**[ARCH: Animatable Reconstruction of Clothed Humans (CVPR 2020)](https://arxiv.org/pdf/2004.04572.pdf)**
|
137 |
+
*Zeng Huang, Yuanlu Xu, Christoph Lassner, Hao Li, Tony Tung*
|
138 |
+
|
139 |
+
Learning PIFu in canonical space for animatable avatar generation!
|
140 |
+
|
141 |
+
**[Robust 3D Self-portraits in Seconds (CVPR 2020)](http://www.liuyebin.com/portrait/portrait.html)**
|
142 |
+
*Zhe Li, Tao Yu, Chuanyu Pan, Zerong Zheng, Yebin Liu*
|
143 |
+
|
144 |
+
They extend PIFu to RGBD + introduce "PIFusion" utilizing PIFu reconstruction for non-rigid fusion.
|
145 |
+
|
146 |
+
**[Learning to Infer Implicit Surfaces without 3d Supervision (NeurIPS 2019)](http://papers.nips.cc/paper/9039-learning-to-infer-implicit-surfaces-without-3d-supervision.pdf)**
|
147 |
+
*Shichen Liu, Shunsuke Saito, Weikai Chen, Hao Li*
|
148 |
+
|
149 |
+
We answer to the question of "how can we learn implicit function if we don't have 3D ground truth?"
|
150 |
+
|
151 |
+
**[SiCloPe: Silhouette-Based Clothed People (CVPR 2019, best paper finalist)](https://arxiv.org/pdf/1901.00049.pdf)**
|
152 |
+
*Ryota Natsume\*, Shunsuke Saito\*, Zeng Huang, Weikai Chen, Chongyang Ma, Hao Li, Shigeo Morishima*
|
153 |
+
|
154 |
+
Our first attempt to reconstruct 3D clothed human body with texture from a single image!
|
155 |
+
|
156 |
+
**[Deep Volumetric Video from Very Sparse Multi-view Performance Capture (ECCV 2018)](http://openaccess.thecvf.com/content_ECCV_2018/papers/Zeng_Huang_Deep_Volumetric_Video_ECCV_2018_paper.pdf)**
|
157 |
+
*Zeng Huang, Tianye Li, Weikai Chen, Yajie Zhao, Jun Xing, Chloe LeGendre, Linjie Luo, Chongyang Ma, Hao Li*
|
158 |
+
|
159 |
+
Implict surface learning for sparse view human performance capture!
|
160 |
+
|
161 |
+
------
|
162 |
+
|
163 |
+
|
164 |
+
|
165 |
+
For commercial queries, please contact:
|
166 |
+
|
167 |
+
Hao Li: hao@hao-li.com ccto: saitos@usc.edu Baker!!
|
PIFu/apps/__init__.py
ADDED
File without changes
|
PIFu/apps/crop_img.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from pathlib import Path
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
def get_bbox(msk):
|
9 |
+
rows = np.any(msk, axis=1)
|
10 |
+
cols = np.any(msk, axis=0)
|
11 |
+
rmin, rmax = np.where(rows)[0][[0,-1]]
|
12 |
+
cmin, cmax = np.where(cols)[0][[0,-1]]
|
13 |
+
|
14 |
+
return rmin, rmax, cmin, cmax
|
15 |
+
|
16 |
+
def process_img(img, msk, bbox=None):
|
17 |
+
if bbox is None:
|
18 |
+
bbox = get_bbox(msk > 100)
|
19 |
+
cx = (bbox[3] + bbox[2])//2
|
20 |
+
cy = (bbox[1] + bbox[0])//2
|
21 |
+
|
22 |
+
w = img.shape[1]
|
23 |
+
h = img.shape[0]
|
24 |
+
height = int(1.138*(bbox[1] - bbox[0]))
|
25 |
+
hh = height//2
|
26 |
+
|
27 |
+
# crop
|
28 |
+
dw = min(cx, w-cx, hh)
|
29 |
+
if cy-hh < 0:
|
30 |
+
img = cv2.copyMakeBorder(img,hh-cy,0,0,0,cv2.BORDER_CONSTANT,value=[0,0,0])
|
31 |
+
msk = cv2.copyMakeBorder(msk,hh-cy,0,0,0,cv2.BORDER_CONSTANT,value=0)
|
32 |
+
cy = hh
|
33 |
+
if cy+hh > h:
|
34 |
+
img = cv2.copyMakeBorder(img,0,cy+hh-h,0,0,cv2.BORDER_CONSTANT,value=[0,0,0])
|
35 |
+
msk = cv2.copyMakeBorder(msk,0,cy+hh-h,0,0,cv2.BORDER_CONSTANT,value=0)
|
36 |
+
img = img[cy-hh:(cy+hh),cx-dw:cx+dw,:]
|
37 |
+
msk = msk[cy-hh:(cy+hh),cx-dw:cx+dw]
|
38 |
+
dw = img.shape[0] - img.shape[1]
|
39 |
+
if dw != 0:
|
40 |
+
img = cv2.copyMakeBorder(img,0,0,dw//2,dw//2,cv2.BORDER_CONSTANT,value=[0,0,0])
|
41 |
+
msk = cv2.copyMakeBorder(msk,0,0,dw//2,dw//2,cv2.BORDER_CONSTANT,value=0)
|
42 |
+
img = cv2.resize(img, (512, 512))
|
43 |
+
msk = cv2.resize(msk, (512, 512))
|
44 |
+
|
45 |
+
kernel = np.ones((3,3),np.uint8)
|
46 |
+
msk = cv2.erode((255*(msk > 100)).astype(np.uint8), kernel, iterations = 1)
|
47 |
+
|
48 |
+
return img, msk
|
49 |
+
|
50 |
+
def main():
|
51 |
+
'''
|
52 |
+
given foreground mask, this script crops and resizes an input image and mask for processing.
|
53 |
+
'''
|
54 |
+
parser = argparse.ArgumentParser()
|
55 |
+
parser.add_argument('-i', '--input_image', type=str, help='if the image has alpha channel, it will be used as mask')
|
56 |
+
parser.add_argument('-m', '--input_mask', type=str)
|
57 |
+
parser.add_argument('-o', '--out_path', type=str, default='./sample_images')
|
58 |
+
args = parser.parse_args()
|
59 |
+
|
60 |
+
img = cv2.imread(args.input_image, cv2.IMREAD_UNCHANGED)
|
61 |
+
if img.shape[2] == 4:
|
62 |
+
msk = img[:,:,3:]
|
63 |
+
img = img[:,:,:3]
|
64 |
+
else:
|
65 |
+
msk = cv2.imread(args.input_mask, cv2.IMREAD_GRAYSCALE)
|
66 |
+
|
67 |
+
img_new, msk_new = process_img(img, msk)
|
68 |
+
|
69 |
+
img_name = Path(args.input_image).stem
|
70 |
+
|
71 |
+
cv2.imwrite(os.path.join(args.out_path, img_name + '.png'), img_new)
|
72 |
+
cv2.imwrite(os.path.join(args.out_path, img_name + '_mask.png'), msk_new)
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
main()
|
PIFu/apps/eval.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tqdm
|
2 |
+
import glob
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
from PIL import Image
|
5 |
+
from lib.model import *
|
6 |
+
from lib.train_util import *
|
7 |
+
from lib.sample_util import *
|
8 |
+
from lib.mesh_util import *
|
9 |
+
# from lib.options import BaseOptions
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
import torch
|
12 |
+
import numpy as np
|
13 |
+
import json
|
14 |
+
import time
|
15 |
+
import sys
|
16 |
+
import os
|
17 |
+
|
18 |
+
sys.path.insert(0, os.path.abspath(
|
19 |
+
os.path.join(os.path.dirname(__file__), '..')))
|
20 |
+
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
21 |
+
|
22 |
+
|
23 |
+
# # get options
|
24 |
+
# opt = BaseOptions().parse()
|
25 |
+
|
26 |
+
class Evaluator:
|
27 |
+
def __init__(self, opt, projection_mode='orthogonal'):
|
28 |
+
self.opt = opt
|
29 |
+
self.load_size = self.opt.loadSize
|
30 |
+
self.to_tensor = transforms.Compose([
|
31 |
+
transforms.Resize(self.load_size),
|
32 |
+
transforms.ToTensor(),
|
33 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
34 |
+
])
|
35 |
+
# set cuda
|
36 |
+
cuda = torch.device(
|
37 |
+
'cuda:%d' % opt.gpu_id) if torch.cuda.is_available() else torch.device('cpu')
|
38 |
+
|
39 |
+
# create net
|
40 |
+
netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
|
41 |
+
print('Using Network: ', netG.name)
|
42 |
+
|
43 |
+
if opt.load_netG_checkpoint_path:
|
44 |
+
netG.load_state_dict(torch.load(
|
45 |
+
opt.load_netG_checkpoint_path, map_location=cuda))
|
46 |
+
|
47 |
+
if opt.load_netC_checkpoint_path is not None:
|
48 |
+
print('loading for net C ...', opt.load_netC_checkpoint_path)
|
49 |
+
netC = ResBlkPIFuNet(opt).to(device=cuda)
|
50 |
+
netC.load_state_dict(torch.load(
|
51 |
+
opt.load_netC_checkpoint_path, map_location=cuda))
|
52 |
+
else:
|
53 |
+
netC = None
|
54 |
+
|
55 |
+
os.makedirs(opt.results_path, exist_ok=True)
|
56 |
+
os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)
|
57 |
+
|
58 |
+
opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
|
59 |
+
with open(opt_log, 'w') as outfile:
|
60 |
+
outfile.write(json.dumps(vars(opt), indent=2))
|
61 |
+
|
62 |
+
self.cuda = cuda
|
63 |
+
self.netG = netG
|
64 |
+
self.netC = netC
|
65 |
+
|
66 |
+
def load_image(self, image_path, mask_path):
|
67 |
+
# Name
|
68 |
+
img_name = os.path.splitext(os.path.basename(image_path))[0]
|
69 |
+
# Calib
|
70 |
+
B_MIN = np.array([-1, -1, -1])
|
71 |
+
B_MAX = np.array([1, 1, 1])
|
72 |
+
projection_matrix = np.identity(4)
|
73 |
+
projection_matrix[1, 1] = -1
|
74 |
+
calib = torch.Tensor(projection_matrix).float()
|
75 |
+
# Mask
|
76 |
+
mask = Image.open(mask_path).convert('L')
|
77 |
+
mask = transforms.Resize(self.load_size)(mask)
|
78 |
+
mask = transforms.ToTensor()(mask).float()
|
79 |
+
# image
|
80 |
+
image = Image.open(image_path).convert('RGB')
|
81 |
+
image = self.to_tensor(image)
|
82 |
+
image = mask.expand_as(image) * image
|
83 |
+
return {
|
84 |
+
'name': img_name,
|
85 |
+
'img': image.unsqueeze(0),
|
86 |
+
'calib': calib.unsqueeze(0),
|
87 |
+
'mask': mask.unsqueeze(0),
|
88 |
+
'b_min': B_MIN,
|
89 |
+
'b_max': B_MAX,
|
90 |
+
}
|
91 |
+
|
92 |
+
def load_image_from_memory(self, image_path, mask_path, img_name):
|
93 |
+
# Calib
|
94 |
+
B_MIN = np.array([-1, -1, -1])
|
95 |
+
B_MAX = np.array([1, 1, 1])
|
96 |
+
projection_matrix = np.identity(4)
|
97 |
+
projection_matrix[1, 1] = -1
|
98 |
+
calib = torch.Tensor(projection_matrix).float()
|
99 |
+
# Mask
|
100 |
+
mask = Image.fromarray(mask_path).convert('L')
|
101 |
+
mask = transforms.Resize(self.load_size)(mask)
|
102 |
+
mask = transforms.ToTensor()(mask).float()
|
103 |
+
# image
|
104 |
+
image = Image.fromarray(image_path).convert('RGB')
|
105 |
+
image = self.to_tensor(image)
|
106 |
+
image = mask.expand_as(image) * image
|
107 |
+
return {
|
108 |
+
'name': img_name,
|
109 |
+
'img': image.unsqueeze(0),
|
110 |
+
'calib': calib.unsqueeze(0),
|
111 |
+
'mask': mask.unsqueeze(0),
|
112 |
+
'b_min': B_MIN,
|
113 |
+
'b_max': B_MAX,
|
114 |
+
}
|
115 |
+
|
116 |
+
def eval(self, data, use_octree=False):
|
117 |
+
'''
|
118 |
+
Evaluate a data point
|
119 |
+
:param data: a dict containing at least ['name'], ['image'], ['calib'], ['b_min'] and ['b_max'] tensors.
|
120 |
+
:return:
|
121 |
+
'''
|
122 |
+
opt = self.opt
|
123 |
+
with torch.no_grad():
|
124 |
+
self.netG.eval()
|
125 |
+
if self.netC:
|
126 |
+
self.netC.eval()
|
127 |
+
save_path = '%s/%s/result_%s.obj' % (
|
128 |
+
opt.results_path, opt.name, data['name'])
|
129 |
+
if self.netC:
|
130 |
+
gen_mesh_color(opt, self.netG, self.netC, self.cuda,
|
131 |
+
data, save_path, use_octree=use_octree)
|
132 |
+
else:
|
133 |
+
gen_mesh(opt, self.netG, self.cuda, data,
|
134 |
+
save_path, use_octree=use_octree)
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == '__main__':
|
138 |
+
evaluator = Evaluator(opt)
|
139 |
+
|
140 |
+
test_images = glob.glob(os.path.join(opt.test_folder_path, '*'))
|
141 |
+
test_images = [f for f in test_images if (
|
142 |
+
'png' in f or 'jpg' in f) and (not 'mask' in f)]
|
143 |
+
test_masks = [f[:-4]+'_mask.png' for f in test_images]
|
144 |
+
|
145 |
+
print("num; ", len(test_masks))
|
146 |
+
|
147 |
+
for image_path, mask_path in tqdm.tqdm(zip(test_images, test_masks)):
|
148 |
+
try:
|
149 |
+
print(image_path, mask_path)
|
150 |
+
data = evaluator.load_image(image_path, mask_path)
|
151 |
+
evaluator.eval(data, True)
|
152 |
+
except Exception as e:
|
153 |
+
print("error:", e.args)
|
PIFu/apps/eval_spaces.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
|
4 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
5 |
+
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
6 |
+
|
7 |
+
import time
|
8 |
+
import json
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
|
13 |
+
from lib.options import BaseOptions
|
14 |
+
from lib.mesh_util import *
|
15 |
+
from lib.sample_util import *
|
16 |
+
from lib.train_util import *
|
17 |
+
from lib.model import *
|
18 |
+
|
19 |
+
from PIL import Image
|
20 |
+
import torchvision.transforms as transforms
|
21 |
+
|
22 |
+
import trimesh
|
23 |
+
from datetime import datetime
|
24 |
+
|
25 |
+
# get options
|
26 |
+
opt = BaseOptions().parse()
|
27 |
+
|
28 |
+
class Evaluator:
|
29 |
+
def __init__(self, opt, projection_mode='orthogonal'):
|
30 |
+
self.opt = opt
|
31 |
+
self.load_size = self.opt.loadSize
|
32 |
+
self.to_tensor = transforms.Compose([
|
33 |
+
transforms.Resize(self.load_size),
|
34 |
+
transforms.ToTensor(),
|
35 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
36 |
+
])
|
37 |
+
# set cuda
|
38 |
+
cuda = torch.device('cuda:%d' % opt.gpu_id) if torch.cuda.is_available() else torch.device('cpu')
|
39 |
+
print("CUDDAAAAA ???", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "NO ONLY CPU")
|
40 |
+
|
41 |
+
# create net
|
42 |
+
netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
|
43 |
+
print('Using Network: ', netG.name)
|
44 |
+
|
45 |
+
if opt.load_netG_checkpoint_path:
|
46 |
+
netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))
|
47 |
+
|
48 |
+
if opt.load_netC_checkpoint_path is not None:
|
49 |
+
print('loading for net C ...', opt.load_netC_checkpoint_path)
|
50 |
+
netC = ResBlkPIFuNet(opt).to(device=cuda)
|
51 |
+
netC.load_state_dict(torch.load(opt.load_netC_checkpoint_path, map_location=cuda))
|
52 |
+
else:
|
53 |
+
netC = None
|
54 |
+
|
55 |
+
os.makedirs(opt.results_path, exist_ok=True)
|
56 |
+
os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)
|
57 |
+
|
58 |
+
opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
|
59 |
+
with open(opt_log, 'w') as outfile:
|
60 |
+
outfile.write(json.dumps(vars(opt), indent=2))
|
61 |
+
|
62 |
+
self.cuda = cuda
|
63 |
+
self.netG = netG
|
64 |
+
self.netC = netC
|
65 |
+
|
66 |
+
def load_image(self, image_path, mask_path):
|
67 |
+
# Name
|
68 |
+
img_name = os.path.splitext(os.path.basename(image_path))[0]
|
69 |
+
# Calib
|
70 |
+
B_MIN = np.array([-1, -1, -1])
|
71 |
+
B_MAX = np.array([1, 1, 1])
|
72 |
+
projection_matrix = np.identity(4)
|
73 |
+
projection_matrix[1, 1] = -1
|
74 |
+
calib = torch.Tensor(projection_matrix).float()
|
75 |
+
# Mask
|
76 |
+
mask = Image.open(mask_path).convert('L')
|
77 |
+
mask = transforms.Resize(self.load_size)(mask)
|
78 |
+
mask = transforms.ToTensor()(mask).float()
|
79 |
+
# image
|
80 |
+
image = Image.open(image_path).convert('RGB')
|
81 |
+
image = self.to_tensor(image)
|
82 |
+
image = mask.expand_as(image) * image
|
83 |
+
return {
|
84 |
+
'name': img_name,
|
85 |
+
'img': image.unsqueeze(0),
|
86 |
+
'calib': calib.unsqueeze(0),
|
87 |
+
'mask': mask.unsqueeze(0),
|
88 |
+
'b_min': B_MIN,
|
89 |
+
'b_max': B_MAX,
|
90 |
+
}
|
91 |
+
|
92 |
+
def eval(self, data, use_octree=False):
|
93 |
+
'''
|
94 |
+
Evaluate a data point
|
95 |
+
:param data: a dict containing at least ['name'], ['image'], ['calib'], ['b_min'] and ['b_max'] tensors.
|
96 |
+
:return:
|
97 |
+
'''
|
98 |
+
opt = self.opt
|
99 |
+
with torch.no_grad():
|
100 |
+
self.netG.eval()
|
101 |
+
if self.netC:
|
102 |
+
self.netC.eval()
|
103 |
+
save_path = '%s/%s/result_%s.obj' % (opt.results_path, opt.name, data['name'])
|
104 |
+
if self.netC:
|
105 |
+
gen_mesh_color(opt, self.netG, self.netC, self.cuda, data, save_path, use_octree=use_octree)
|
106 |
+
else:
|
107 |
+
gen_mesh(opt, self.netG, self.cuda, data, save_path, use_octree=use_octree)
|
108 |
+
|
109 |
+
|
110 |
+
if __name__ == '__main__':
|
111 |
+
evaluator = Evaluator(opt)
|
112 |
+
|
113 |
+
results_path = opt.results_path
|
114 |
+
name = opt.name
|
115 |
+
test_image_path = opt.img_path
|
116 |
+
test_mask_path = test_image_path[:-4] +'_mask.png'
|
117 |
+
test_img_name = os.path.splitext(os.path.basename(test_image_path))[0]
|
118 |
+
print("test_image: ", test_image_path)
|
119 |
+
print("test_mask: ", test_mask_path)
|
120 |
+
|
121 |
+
try:
|
122 |
+
time = datetime.now()
|
123 |
+
print("evaluating" , time)
|
124 |
+
data = evaluator.load_image(test_image_path, test_mask_path)
|
125 |
+
evaluator.eval(data, False)
|
126 |
+
print("done evaluating" , datetime.now() - time)
|
127 |
+
except Exception as e:
|
128 |
+
print("error:", e.args)
|
129 |
+
|
130 |
+
try:
|
131 |
+
mesh = trimesh.load(f'{results_path}/{name}/result_{test_img_name}.obj')
|
132 |
+
mesh.apply_transform([[1, 0, 0, 0],
|
133 |
+
[0, 1, 0, 0],
|
134 |
+
[0, 0, -1, 0],
|
135 |
+
[0, 0, 0, 1]])
|
136 |
+
mesh.export(file_obj=f'{results_path}/{name}/result_{test_img_name}.glb')
|
137 |
+
except Exception as e:
|
138 |
+
print("error generating MESH", e)
|
PIFu/apps/prt_util.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import trimesh
|
3 |
+
import numpy as np
|
4 |
+
import math
|
5 |
+
from scipy.special import sph_harm
|
6 |
+
import argparse
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
def factratio(N, D):
|
10 |
+
if N >= D:
|
11 |
+
prod = 1.0
|
12 |
+
for i in range(D+1, N+1):
|
13 |
+
prod *= i
|
14 |
+
return prod
|
15 |
+
else:
|
16 |
+
prod = 1.0
|
17 |
+
for i in range(N+1, D+1):
|
18 |
+
prod *= i
|
19 |
+
return 1.0 / prod
|
20 |
+
|
21 |
+
def KVal(M, L):
|
22 |
+
return math.sqrt(((2 * L + 1) / (4 * math.pi)) * (factratio(L - M, L + M)))
|
23 |
+
|
24 |
+
def AssociatedLegendre(M, L, x):
|
25 |
+
if M < 0 or M > L or np.max(np.abs(x)) > 1.0:
|
26 |
+
return np.zeros_like(x)
|
27 |
+
|
28 |
+
pmm = np.ones_like(x)
|
29 |
+
if M > 0:
|
30 |
+
somx2 = np.sqrt((1.0 + x) * (1.0 - x))
|
31 |
+
fact = 1.0
|
32 |
+
for i in range(1, M+1):
|
33 |
+
pmm = -pmm * fact * somx2
|
34 |
+
fact = fact + 2
|
35 |
+
|
36 |
+
if L == M:
|
37 |
+
return pmm
|
38 |
+
else:
|
39 |
+
pmmp1 = x * (2 * M + 1) * pmm
|
40 |
+
if L == M+1:
|
41 |
+
return pmmp1
|
42 |
+
else:
|
43 |
+
pll = np.zeros_like(x)
|
44 |
+
for i in range(M+2, L+1):
|
45 |
+
pll = (x * (2 * i - 1) * pmmp1 - (i + M - 1) * pmm) / (i - M)
|
46 |
+
pmm = pmmp1
|
47 |
+
pmmp1 = pll
|
48 |
+
return pll
|
49 |
+
|
50 |
+
def SphericalHarmonic(M, L, theta, phi):
|
51 |
+
if M > 0:
|
52 |
+
return math.sqrt(2.0) * KVal(M, L) * np.cos(M * phi) * AssociatedLegendre(M, L, np.cos(theta))
|
53 |
+
elif M < 0:
|
54 |
+
return math.sqrt(2.0) * KVal(-M, L) * np.sin(-M * phi) * AssociatedLegendre(-M, L, np.cos(theta))
|
55 |
+
else:
|
56 |
+
return KVal(0, L) * AssociatedLegendre(0, L, np.cos(theta))
|
57 |
+
|
58 |
+
def save_obj(mesh_path, verts):
|
59 |
+
file = open(mesh_path, 'w')
|
60 |
+
for v in verts:
|
61 |
+
file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
|
62 |
+
file.close()
|
63 |
+
|
64 |
+
def sampleSphericalDirections(n):
|
65 |
+
xv = np.random.rand(n,n)
|
66 |
+
yv = np.random.rand(n,n)
|
67 |
+
theta = np.arccos(1-2 * xv)
|
68 |
+
phi = 2.0 * math.pi * yv
|
69 |
+
|
70 |
+
phi = phi.reshape(-1)
|
71 |
+
theta = theta.reshape(-1)
|
72 |
+
|
73 |
+
vx = -np.sin(theta) * np.cos(phi)
|
74 |
+
vy = -np.sin(theta) * np.sin(phi)
|
75 |
+
vz = np.cos(theta)
|
76 |
+
return np.stack([vx, vy, vz], 1), phi, theta
|
77 |
+
|
78 |
+
def getSHCoeffs(order, phi, theta):
|
79 |
+
shs = []
|
80 |
+
for n in range(0, order+1):
|
81 |
+
for m in range(-n,n+1):
|
82 |
+
s = SphericalHarmonic(m, n, theta, phi)
|
83 |
+
shs.append(s)
|
84 |
+
|
85 |
+
return np.stack(shs, 1)
|
86 |
+
|
87 |
+
def computePRT(mesh_path, n, order):
|
88 |
+
mesh = trimesh.load(mesh_path, process=False)
|
89 |
+
vectors_orig, phi, theta = sampleSphericalDirections(n)
|
90 |
+
SH_orig = getSHCoeffs(order, phi, theta)
|
91 |
+
|
92 |
+
w = 4.0 * math.pi / (n*n)
|
93 |
+
|
94 |
+
origins = mesh.vertices
|
95 |
+
normals = mesh.vertex_normals
|
96 |
+
n_v = origins.shape[0]
|
97 |
+
|
98 |
+
origins = np.repeat(origins[:,None], n, axis=1).reshape(-1,3)
|
99 |
+
normals = np.repeat(normals[:,None], n, axis=1).reshape(-1,3)
|
100 |
+
PRT_all = None
|
101 |
+
for i in tqdm(range(n)):
|
102 |
+
SH = np.repeat(SH_orig[None,(i*n):((i+1)*n)], n_v, axis=0).reshape(-1,SH_orig.shape[1])
|
103 |
+
vectors = np.repeat(vectors_orig[None,(i*n):((i+1)*n)], n_v, axis=0).reshape(-1,3)
|
104 |
+
|
105 |
+
dots = (vectors * normals).sum(1)
|
106 |
+
front = (dots > 0.0)
|
107 |
+
|
108 |
+
delta = 1e-3*min(mesh.bounding_box.extents)
|
109 |
+
hits = mesh.ray.intersects_any(origins + delta * normals, vectors)
|
110 |
+
nohits = np.logical_and(front, np.logical_not(hits))
|
111 |
+
|
112 |
+
PRT = (nohits.astype(np.float) * dots)[:,None] * SH
|
113 |
+
|
114 |
+
if PRT_all is not None:
|
115 |
+
PRT_all += (PRT.reshape(-1, n, SH.shape[1]).sum(1))
|
116 |
+
else:
|
117 |
+
PRT_all = (PRT.reshape(-1, n, SH.shape[1]).sum(1))
|
118 |
+
|
119 |
+
PRT = w * PRT_all
|
120 |
+
|
121 |
+
# NOTE: trimesh sometimes break the original vertex order, but topology will not change.
|
122 |
+
# when loading PRT in other program, use the triangle list from trimesh.
|
123 |
+
return PRT, mesh.faces
|
124 |
+
|
125 |
+
def testPRT(dir_path, n=40):
|
126 |
+
if dir_path[-1] == '/':
|
127 |
+
dir_path = dir_path[:-1]
|
128 |
+
sub_name = dir_path.split('/')[-1][:-4]
|
129 |
+
obj_path = os.path.join(dir_path, sub_name + '_100k.obj')
|
130 |
+
os.makedirs(os.path.join(dir_path, 'bounce'), exist_ok=True)
|
131 |
+
|
132 |
+
PRT, F = computePRT(obj_path, n, 2)
|
133 |
+
np.savetxt(os.path.join(dir_path, 'bounce', 'bounce0.txt'), PRT, fmt='%.8f')
|
134 |
+
np.save(os.path.join(dir_path, 'bounce', 'face.npy'), F)
|
135 |
+
|
136 |
+
if __name__ == '__main__':
|
137 |
+
parser = argparse.ArgumentParser()
|
138 |
+
parser.add_argument('-i', '--input', type=str, default='/home/shunsuke/Downloads/rp_dennis_posed_004_OBJ')
|
139 |
+
parser.add_argument('-n', '--n_sample', type=int, default=40, help='squared root of number of sampling. the higher, the more accurate, but slower')
|
140 |
+
args = parser.parse_args()
|
141 |
+
|
142 |
+
testPRT(args.input)
|
PIFu/apps/render_data.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#from data.config import raw_dataset, render_dataset, archive_dataset, model_list, zip_path
|
2 |
+
|
3 |
+
from lib.renderer.camera import Camera
|
4 |
+
import numpy as np
|
5 |
+
from lib.renderer.mesh import load_obj_mesh, compute_tangent, compute_normal, load_obj_mesh_mtl
|
6 |
+
from lib.renderer.camera import Camera
|
7 |
+
import os
|
8 |
+
import cv2
|
9 |
+
import time
|
10 |
+
import math
|
11 |
+
import random
|
12 |
+
import pyexr
|
13 |
+
import argparse
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
|
17 |
+
def make_rotate(rx, ry, rz):
|
18 |
+
sinX = np.sin(rx)
|
19 |
+
sinY = np.sin(ry)
|
20 |
+
sinZ = np.sin(rz)
|
21 |
+
|
22 |
+
cosX = np.cos(rx)
|
23 |
+
cosY = np.cos(ry)
|
24 |
+
cosZ = np.cos(rz)
|
25 |
+
|
26 |
+
Rx = np.zeros((3,3))
|
27 |
+
Rx[0, 0] = 1.0
|
28 |
+
Rx[1, 1] = cosX
|
29 |
+
Rx[1, 2] = -sinX
|
30 |
+
Rx[2, 1] = sinX
|
31 |
+
Rx[2, 2] = cosX
|
32 |
+
|
33 |
+
Ry = np.zeros((3,3))
|
34 |
+
Ry[0, 0] = cosY
|
35 |
+
Ry[0, 2] = sinY
|
36 |
+
Ry[1, 1] = 1.0
|
37 |
+
Ry[2, 0] = -sinY
|
38 |
+
Ry[2, 2] = cosY
|
39 |
+
|
40 |
+
Rz = np.zeros((3,3))
|
41 |
+
Rz[0, 0] = cosZ
|
42 |
+
Rz[0, 1] = -sinZ
|
43 |
+
Rz[1, 0] = sinZ
|
44 |
+
Rz[1, 1] = cosZ
|
45 |
+
Rz[2, 2] = 1.0
|
46 |
+
|
47 |
+
R = np.matmul(np.matmul(Rz,Ry),Rx)
|
48 |
+
return R
|
49 |
+
|
50 |
+
def rotateSH(SH, R):
|
51 |
+
SHn = SH
|
52 |
+
|
53 |
+
# 1st order
|
54 |
+
SHn[1] = R[1,1]*SH[1] - R[1,2]*SH[2] + R[1,0]*SH[3]
|
55 |
+
SHn[2] = -R[2,1]*SH[1] + R[2,2]*SH[2] - R[2,0]*SH[3]
|
56 |
+
SHn[3] = R[0,1]*SH[1] - R[0,2]*SH[2] + R[0,0]*SH[3]
|
57 |
+
|
58 |
+
# 2nd order
|
59 |
+
SHn[4:,0] = rotateBand2(SH[4:,0],R)
|
60 |
+
SHn[4:,1] = rotateBand2(SH[4:,1],R)
|
61 |
+
SHn[4:,2] = rotateBand2(SH[4:,2],R)
|
62 |
+
|
63 |
+
return SHn
|
64 |
+
|
65 |
+
def rotateBand2(x, R):
|
66 |
+
s_c3 = 0.94617469575
|
67 |
+
s_c4 = -0.31539156525
|
68 |
+
s_c5 = 0.54627421529
|
69 |
+
|
70 |
+
s_c_scale = 1.0/0.91529123286551084
|
71 |
+
s_c_scale_inv = 0.91529123286551084
|
72 |
+
|
73 |
+
s_rc2 = 1.5853309190550713*s_c_scale
|
74 |
+
s_c4_div_c3 = s_c4/s_c3
|
75 |
+
s_c4_div_c3_x2 = (s_c4/s_c3)*2.0
|
76 |
+
|
77 |
+
s_scale_dst2 = s_c3 * s_c_scale_inv
|
78 |
+
s_scale_dst4 = s_c5 * s_c_scale_inv
|
79 |
+
|
80 |
+
sh0 = x[3] + x[4] + x[4] - x[1]
|
81 |
+
sh1 = x[0] + s_rc2*x[2] + x[3] + x[4]
|
82 |
+
sh2 = x[0]
|
83 |
+
sh3 = -x[3]
|
84 |
+
sh4 = -x[1]
|
85 |
+
|
86 |
+
r2x = R[0][0] + R[0][1]
|
87 |
+
r2y = R[1][0] + R[1][1]
|
88 |
+
r2z = R[2][0] + R[2][1]
|
89 |
+
|
90 |
+
r3x = R[0][0] + R[0][2]
|
91 |
+
r3y = R[1][0] + R[1][2]
|
92 |
+
r3z = R[2][0] + R[2][2]
|
93 |
+
|
94 |
+
r4x = R[0][1] + R[0][2]
|
95 |
+
r4y = R[1][1] + R[1][2]
|
96 |
+
r4z = R[2][1] + R[2][2]
|
97 |
+
|
98 |
+
sh0_x = sh0 * R[0][0]
|
99 |
+
sh0_y = sh0 * R[1][0]
|
100 |
+
d0 = sh0_x * R[1][0]
|
101 |
+
d1 = sh0_y * R[2][0]
|
102 |
+
d2 = sh0 * (R[2][0] * R[2][0] + s_c4_div_c3)
|
103 |
+
d3 = sh0_x * R[2][0]
|
104 |
+
d4 = sh0_x * R[0][0] - sh0_y * R[1][0]
|
105 |
+
|
106 |
+
sh1_x = sh1 * R[0][2]
|
107 |
+
sh1_y = sh1 * R[1][2]
|
108 |
+
d0 += sh1_x * R[1][2]
|
109 |
+
d1 += sh1_y * R[2][2]
|
110 |
+
d2 += sh1 * (R[2][2] * R[2][2] + s_c4_div_c3)
|
111 |
+
d3 += sh1_x * R[2][2]
|
112 |
+
d4 += sh1_x * R[0][2] - sh1_y * R[1][2]
|
113 |
+
|
114 |
+
sh2_x = sh2 * r2x
|
115 |
+
sh2_y = sh2 * r2y
|
116 |
+
d0 += sh2_x * r2y
|
117 |
+
d1 += sh2_y * r2z
|
118 |
+
d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2)
|
119 |
+
d3 += sh2_x * r2z
|
120 |
+
d4 += sh2_x * r2x - sh2_y * r2y
|
121 |
+
|
122 |
+
sh3_x = sh3 * r3x
|
123 |
+
sh3_y = sh3 * r3y
|
124 |
+
d0 += sh3_x * r3y
|
125 |
+
d1 += sh3_y * r3z
|
126 |
+
d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2)
|
127 |
+
d3 += sh3_x * r3z
|
128 |
+
d4 += sh3_x * r3x - sh3_y * r3y
|
129 |
+
|
130 |
+
sh4_x = sh4 * r4x
|
131 |
+
sh4_y = sh4 * r4y
|
132 |
+
d0 += sh4_x * r4y
|
133 |
+
d1 += sh4_y * r4z
|
134 |
+
d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2)
|
135 |
+
d3 += sh4_x * r4z
|
136 |
+
d4 += sh4_x * r4x - sh4_y * r4y
|
137 |
+
|
138 |
+
dst = x
|
139 |
+
dst[0] = d0
|
140 |
+
dst[1] = -d1
|
141 |
+
dst[2] = d2 * s_scale_dst2
|
142 |
+
dst[3] = -d3
|
143 |
+
dst[4] = d4 * s_scale_dst4
|
144 |
+
|
145 |
+
return dst
|
146 |
+
|
147 |
+
def render_prt_ortho(out_path, folder_name, subject_name, shs, rndr, rndr_uv, im_size, angl_step=4, n_light=1, pitch=[0]):
|
148 |
+
cam = Camera(width=im_size, height=im_size)
|
149 |
+
cam.ortho_ratio = 0.4 * (512 / im_size)
|
150 |
+
cam.near = -100
|
151 |
+
cam.far = 100
|
152 |
+
cam.sanity_check()
|
153 |
+
|
154 |
+
# set path for obj, prt
|
155 |
+
mesh_file = os.path.join(folder_name, subject_name + '_100k.obj')
|
156 |
+
if not os.path.exists(mesh_file):
|
157 |
+
print('ERROR: obj file does not exist!!', mesh_file)
|
158 |
+
return
|
159 |
+
prt_file = os.path.join(folder_name, 'bounce', 'bounce0.txt')
|
160 |
+
if not os.path.exists(prt_file):
|
161 |
+
print('ERROR: prt file does not exist!!!', prt_file)
|
162 |
+
return
|
163 |
+
face_prt_file = os.path.join(folder_name, 'bounce', 'face.npy')
|
164 |
+
if not os.path.exists(face_prt_file):
|
165 |
+
print('ERROR: face prt file does not exist!!!', prt_file)
|
166 |
+
return
|
167 |
+
text_file = os.path.join(folder_name, 'tex', subject_name + '_dif_2k.jpg')
|
168 |
+
if not os.path.exists(text_file):
|
169 |
+
print('ERROR: dif file does not exist!!', text_file)
|
170 |
+
return
|
171 |
+
|
172 |
+
texture_image = cv2.imread(text_file)
|
173 |
+
texture_image = cv2.cvtColor(texture_image, cv2.COLOR_BGR2RGB)
|
174 |
+
|
175 |
+
vertices, faces, normals, faces_normals, textures, face_textures = load_obj_mesh(mesh_file, with_normal=True, with_texture=True)
|
176 |
+
vmin = vertices.min(0)
|
177 |
+
vmax = vertices.max(0)
|
178 |
+
up_axis = 1 if (vmax-vmin).argmax() == 1 else 2
|
179 |
+
|
180 |
+
vmed = np.median(vertices, 0)
|
181 |
+
vmed[up_axis] = 0.5*(vmax[up_axis]+vmin[up_axis])
|
182 |
+
y_scale = 180/(vmax[up_axis] - vmin[up_axis])
|
183 |
+
|
184 |
+
rndr.set_norm_mat(y_scale, vmed)
|
185 |
+
rndr_uv.set_norm_mat(y_scale, vmed)
|
186 |
+
|
187 |
+
tan, bitan = compute_tangent(vertices, faces, normals, textures, face_textures)
|
188 |
+
prt = np.loadtxt(prt_file)
|
189 |
+
face_prt = np.load(face_prt_file)
|
190 |
+
rndr.set_mesh(vertices, faces, normals, faces_normals, textures, face_textures, prt, face_prt, tan, bitan)
|
191 |
+
rndr.set_albedo(texture_image)
|
192 |
+
|
193 |
+
rndr_uv.set_mesh(vertices, faces, normals, faces_normals, textures, face_textures, prt, face_prt, tan, bitan)
|
194 |
+
rndr_uv.set_albedo(texture_image)
|
195 |
+
|
196 |
+
os.makedirs(os.path.join(out_path, 'GEO', 'OBJ', subject_name),exist_ok=True)
|
197 |
+
os.makedirs(os.path.join(out_path, 'PARAM', subject_name),exist_ok=True)
|
198 |
+
os.makedirs(os.path.join(out_path, 'RENDER', subject_name),exist_ok=True)
|
199 |
+
os.makedirs(os.path.join(out_path, 'MASK', subject_name),exist_ok=True)
|
200 |
+
os.makedirs(os.path.join(out_path, 'UV_RENDER', subject_name),exist_ok=True)
|
201 |
+
os.makedirs(os.path.join(out_path, 'UV_MASK', subject_name),exist_ok=True)
|
202 |
+
os.makedirs(os.path.join(out_path, 'UV_POS', subject_name),exist_ok=True)
|
203 |
+
os.makedirs(os.path.join(out_path, 'UV_NORMAL', subject_name),exist_ok=True)
|
204 |
+
|
205 |
+
if not os.path.exists(os.path.join(out_path, 'val.txt')):
|
206 |
+
f = open(os.path.join(out_path, 'val.txt'), 'w')
|
207 |
+
f.close()
|
208 |
+
|
209 |
+
# copy obj file
|
210 |
+
cmd = 'cp %s %s' % (mesh_file, os.path.join(out_path, 'GEO', 'OBJ', subject_name))
|
211 |
+
print(cmd)
|
212 |
+
os.system(cmd)
|
213 |
+
|
214 |
+
for p in pitch:
|
215 |
+
for y in tqdm(range(0, 360, angl_step)):
|
216 |
+
R = np.matmul(make_rotate(math.radians(p), 0, 0), make_rotate(0, math.radians(y), 0))
|
217 |
+
if up_axis == 2:
|
218 |
+
R = np.matmul(R, make_rotate(math.radians(90),0,0))
|
219 |
+
|
220 |
+
rndr.rot_matrix = R
|
221 |
+
rndr_uv.rot_matrix = R
|
222 |
+
rndr.set_camera(cam)
|
223 |
+
rndr_uv.set_camera(cam)
|
224 |
+
|
225 |
+
for j in range(n_light):
|
226 |
+
sh_id = random.randint(0,shs.shape[0]-1)
|
227 |
+
sh = shs[sh_id]
|
228 |
+
sh_angle = 0.2*np.pi*(random.random()-0.5)
|
229 |
+
sh = rotateSH(sh, make_rotate(0, sh_angle, 0).T)
|
230 |
+
|
231 |
+
dic = {'sh': sh, 'ortho_ratio': cam.ortho_ratio, 'scale': y_scale, 'center': vmed, 'R': R}
|
232 |
+
|
233 |
+
rndr.set_sh(sh)
|
234 |
+
rndr.analytic = False
|
235 |
+
rndr.use_inverse_depth = False
|
236 |
+
rndr.display()
|
237 |
+
|
238 |
+
out_all_f = rndr.get_color(0)
|
239 |
+
out_mask = out_all_f[:,:,3]
|
240 |
+
out_all_f = cv2.cvtColor(out_all_f, cv2.COLOR_RGBA2BGR)
|
241 |
+
|
242 |
+
np.save(os.path.join(out_path, 'PARAM', subject_name, '%d_%d_%02d.npy'%(y,p,j)),dic)
|
243 |
+
cv2.imwrite(os.path.join(out_path, 'RENDER', subject_name, '%d_%d_%02d.jpg'%(y,p,j)),255.0*out_all_f)
|
244 |
+
cv2.imwrite(os.path.join(out_path, 'MASK', subject_name, '%d_%d_%02d.png'%(y,p,j)),255.0*out_mask)
|
245 |
+
|
246 |
+
rndr_uv.set_sh(sh)
|
247 |
+
rndr_uv.analytic = False
|
248 |
+
rndr_uv.use_inverse_depth = False
|
249 |
+
rndr_uv.display()
|
250 |
+
|
251 |
+
uv_color = rndr_uv.get_color(0)
|
252 |
+
uv_color = cv2.cvtColor(uv_color, cv2.COLOR_RGBA2BGR)
|
253 |
+
cv2.imwrite(os.path.join(out_path, 'UV_RENDER', subject_name, '%d_%d_%02d.jpg'%(y,p,j)),255.0*uv_color)
|
254 |
+
|
255 |
+
if y == 0 and j == 0 and p == pitch[0]:
|
256 |
+
uv_pos = rndr_uv.get_color(1)
|
257 |
+
uv_mask = uv_pos[:,:,3]
|
258 |
+
cv2.imwrite(os.path.join(out_path, 'UV_MASK', subject_name, '00.png'),255.0*uv_mask)
|
259 |
+
|
260 |
+
data = {'default': uv_pos[:,:,:3]} # default is a reserved name
|
261 |
+
pyexr.write(os.path.join(out_path, 'UV_POS', subject_name, '00.exr'), data)
|
262 |
+
|
263 |
+
uv_nml = rndr_uv.get_color(2)
|
264 |
+
uv_nml = cv2.cvtColor(uv_nml, cv2.COLOR_RGBA2BGR)
|
265 |
+
cv2.imwrite(os.path.join(out_path, 'UV_NORMAL', subject_name, '00.png'),255.0*uv_nml)
|
266 |
+
|
267 |
+
|
268 |
+
if __name__ == '__main__':
|
269 |
+
shs = np.load('./env_sh.npy')
|
270 |
+
|
271 |
+
parser = argparse.ArgumentParser()
|
272 |
+
parser.add_argument('-i', '--input', type=str, default='/home/shunsuke/Downloads/rp_dennis_posed_004_OBJ')
|
273 |
+
parser.add_argument('-o', '--out_dir', type=str, default='/home/shunsuke/Documents/hf_human')
|
274 |
+
parser.add_argument('-m', '--ms_rate', type=int, default=1, help='higher ms rate results in less aliased output. MESA renderer only supports ms_rate=1.')
|
275 |
+
parser.add_argument('-e', '--egl', action='store_true', help='egl rendering option. use this when rendering with headless server with NVIDIA GPU')
|
276 |
+
parser.add_argument('-s', '--size', type=int, default=512, help='rendering image size')
|
277 |
+
args = parser.parse_args()
|
278 |
+
|
279 |
+
# NOTE: GL context has to be created before any other OpenGL function loads.
|
280 |
+
from lib.renderer.gl.init_gl import initialize_GL_context
|
281 |
+
initialize_GL_context(width=args.size, height=args.size, egl=args.egl)
|
282 |
+
|
283 |
+
from lib.renderer.gl.prt_render import PRTRender
|
284 |
+
rndr = PRTRender(width=args.size, height=args.size, ms_rate=args.ms_rate, egl=args.egl)
|
285 |
+
rndr_uv = PRTRender(width=args.size, height=args.size, uv_mode=True, egl=args.egl)
|
286 |
+
|
287 |
+
if args.input[-1] == '/':
|
288 |
+
args.input = args.input[:-1]
|
289 |
+
subject_name = args.input.split('/')[-1][:-4]
|
290 |
+
render_prt_ortho(args.out_dir, args.input, subject_name, shs, rndr, rndr_uv, args.size, 1, 1, pitch=[0])
|
PIFu/apps/train_color.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
|
4 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
5 |
+
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
6 |
+
|
7 |
+
import time
|
8 |
+
import json
|
9 |
+
import numpy as np
|
10 |
+
import cv2
|
11 |
+
import random
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from torch.utils.data import DataLoader
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
from lib.options import BaseOptions
|
18 |
+
from lib.mesh_util import *
|
19 |
+
from lib.sample_util import *
|
20 |
+
from lib.train_util import *
|
21 |
+
from lib.data import *
|
22 |
+
from lib.model import *
|
23 |
+
from lib.geometry import index
|
24 |
+
|
25 |
+
# get options
|
26 |
+
opt = BaseOptions().parse()
|
27 |
+
|
28 |
+
def train_color(opt):
|
29 |
+
# set cuda
|
30 |
+
cuda = torch.device('cuda:%d' % opt.gpu_id)
|
31 |
+
|
32 |
+
train_dataset = TrainDataset(opt, phase='train')
|
33 |
+
test_dataset = TrainDataset(opt, phase='test')
|
34 |
+
|
35 |
+
projection_mode = train_dataset.projection_mode
|
36 |
+
|
37 |
+
# create data loader
|
38 |
+
train_data_loader = DataLoader(train_dataset,
|
39 |
+
batch_size=opt.batch_size, shuffle=not opt.serial_batches,
|
40 |
+
num_workers=opt.num_threads, pin_memory=opt.pin_memory)
|
41 |
+
|
42 |
+
print('train data size: ', len(train_data_loader))
|
43 |
+
|
44 |
+
# NOTE: batch size should be 1 and use all the points for evaluation
|
45 |
+
test_data_loader = DataLoader(test_dataset,
|
46 |
+
batch_size=1, shuffle=False,
|
47 |
+
num_workers=opt.num_threads, pin_memory=opt.pin_memory)
|
48 |
+
print('test data size: ', len(test_data_loader))
|
49 |
+
|
50 |
+
# create net
|
51 |
+
netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
|
52 |
+
|
53 |
+
lr = opt.learning_rate
|
54 |
+
|
55 |
+
# Always use resnet for color regression
|
56 |
+
netC = ResBlkPIFuNet(opt).to(device=cuda)
|
57 |
+
optimizerC = torch.optim.Adam(netC.parameters(), lr=opt.learning_rate)
|
58 |
+
|
59 |
+
def set_train():
|
60 |
+
netG.eval()
|
61 |
+
netC.train()
|
62 |
+
|
63 |
+
def set_eval():
|
64 |
+
netG.eval()
|
65 |
+
netC.eval()
|
66 |
+
|
67 |
+
print('Using NetworkG: ', netG.name, 'networkC: ', netC.name)
|
68 |
+
|
69 |
+
# load checkpoints
|
70 |
+
if opt.load_netG_checkpoint_path is not None:
|
71 |
+
print('loading for net G ...', opt.load_netG_checkpoint_path)
|
72 |
+
netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))
|
73 |
+
else:
|
74 |
+
model_path_G = '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name)
|
75 |
+
print('loading for net G ...', model_path_G)
|
76 |
+
netG.load_state_dict(torch.load(model_path_G, map_location=cuda))
|
77 |
+
|
78 |
+
if opt.load_netC_checkpoint_path is not None:
|
79 |
+
print('loading for net C ...', opt.load_netC_checkpoint_path)
|
80 |
+
netC.load_state_dict(torch.load(opt.load_netC_checkpoint_path, map_location=cuda))
|
81 |
+
|
82 |
+
if opt.continue_train:
|
83 |
+
if opt.resume_epoch < 0:
|
84 |
+
model_path_C = '%s/%s/netC_latest' % (opt.checkpoints_path, opt.name)
|
85 |
+
else:
|
86 |
+
model_path_C = '%s/%s/netC_epoch_%d' % (opt.checkpoints_path, opt.name, opt.resume_epoch)
|
87 |
+
|
88 |
+
print('Resuming from ', model_path_C)
|
89 |
+
netC.load_state_dict(torch.load(model_path_C, map_location=cuda))
|
90 |
+
|
91 |
+
os.makedirs(opt.checkpoints_path, exist_ok=True)
|
92 |
+
os.makedirs(opt.results_path, exist_ok=True)
|
93 |
+
os.makedirs('%s/%s' % (opt.checkpoints_path, opt.name), exist_ok=True)
|
94 |
+
os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)
|
95 |
+
|
96 |
+
opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
|
97 |
+
with open(opt_log, 'w') as outfile:
|
98 |
+
outfile.write(json.dumps(vars(opt), indent=2))
|
99 |
+
|
100 |
+
# training
|
101 |
+
start_epoch = 0 if not opt.continue_train else max(opt.resume_epoch,0)
|
102 |
+
for epoch in range(start_epoch, opt.num_epoch):
|
103 |
+
epoch_start_time = time.time()
|
104 |
+
|
105 |
+
set_train()
|
106 |
+
iter_data_time = time.time()
|
107 |
+
for train_idx, train_data in enumerate(train_data_loader):
|
108 |
+
iter_start_time = time.time()
|
109 |
+
# retrieve the data
|
110 |
+
image_tensor = train_data['img'].to(device=cuda)
|
111 |
+
calib_tensor = train_data['calib'].to(device=cuda)
|
112 |
+
color_sample_tensor = train_data['color_samples'].to(device=cuda)
|
113 |
+
|
114 |
+
image_tensor, calib_tensor = reshape_multiview_tensors(image_tensor, calib_tensor)
|
115 |
+
|
116 |
+
if opt.num_views > 1:
|
117 |
+
color_sample_tensor = reshape_sample_tensor(color_sample_tensor, opt.num_views)
|
118 |
+
|
119 |
+
rgb_tensor = train_data['rgbs'].to(device=cuda)
|
120 |
+
|
121 |
+
with torch.no_grad():
|
122 |
+
netG.filter(image_tensor)
|
123 |
+
resC, error = netC.forward(image_tensor, netG.get_im_feat(), color_sample_tensor, calib_tensor, labels=rgb_tensor)
|
124 |
+
|
125 |
+
optimizerC.zero_grad()
|
126 |
+
error.backward()
|
127 |
+
optimizerC.step()
|
128 |
+
|
129 |
+
iter_net_time = time.time()
|
130 |
+
eta = ((iter_net_time - epoch_start_time) / (train_idx + 1)) * len(train_data_loader) - (
|
131 |
+
iter_net_time - epoch_start_time)
|
132 |
+
|
133 |
+
if train_idx % opt.freq_plot == 0:
|
134 |
+
print(
|
135 |
+
'Name: {0} | Epoch: {1} | {2}/{3} | Err: {4:.06f} | LR: {5:.06f} | dataT: {6:.05f} | netT: {7:.05f} | ETA: {8:02d}:{9:02d}'.format(
|
136 |
+
opt.name, epoch, train_idx, len(train_data_loader),
|
137 |
+
error.item(),
|
138 |
+
lr,
|
139 |
+
iter_start_time - iter_data_time,
|
140 |
+
iter_net_time - iter_start_time, int(eta // 60),
|
141 |
+
int(eta - 60 * (eta // 60))))
|
142 |
+
|
143 |
+
if train_idx % opt.freq_save == 0 and train_idx != 0:
|
144 |
+
torch.save(netC.state_dict(), '%s/%s/netC_latest' % (opt.checkpoints_path, opt.name))
|
145 |
+
torch.save(netC.state_dict(), '%s/%s/netC_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))
|
146 |
+
|
147 |
+
if train_idx % opt.freq_save_ply == 0:
|
148 |
+
save_path = '%s/%s/pred_col.ply' % (opt.results_path, opt.name)
|
149 |
+
rgb = resC[0].transpose(0, 1).cpu() * 0.5 + 0.5
|
150 |
+
points = color_sample_tensor[0].transpose(0, 1).cpu()
|
151 |
+
save_samples_rgb(save_path, points.detach().numpy(), rgb.detach().numpy())
|
152 |
+
|
153 |
+
iter_data_time = time.time()
|
154 |
+
|
155 |
+
#### test
|
156 |
+
with torch.no_grad():
|
157 |
+
set_eval()
|
158 |
+
|
159 |
+
if not opt.no_num_eval:
|
160 |
+
test_losses = {}
|
161 |
+
print('calc error (test) ...')
|
162 |
+
test_color_error = calc_error_color(opt, netG, netC, cuda, test_dataset, 100)
|
163 |
+
print('eval test | color error:', test_color_error)
|
164 |
+
test_losses['test_color'] = test_color_error
|
165 |
+
|
166 |
+
print('calc error (train) ...')
|
167 |
+
train_dataset.is_train = False
|
168 |
+
train_color_error = calc_error_color(opt, netG, netC, cuda, train_dataset, 100)
|
169 |
+
train_dataset.is_train = True
|
170 |
+
print('eval train | color error:', train_color_error)
|
171 |
+
test_losses['train_color'] = train_color_error
|
172 |
+
|
173 |
+
if not opt.no_gen_mesh:
|
174 |
+
print('generate mesh (test) ...')
|
175 |
+
for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
|
176 |
+
test_data = random.choice(test_dataset)
|
177 |
+
save_path = '%s/%s/test_eval_epoch%d_%s.obj' % (
|
178 |
+
opt.results_path, opt.name, epoch, test_data['name'])
|
179 |
+
gen_mesh_color(opt, netG, netC, cuda, test_data, save_path)
|
180 |
+
|
181 |
+
print('generate mesh (train) ...')
|
182 |
+
train_dataset.is_train = False
|
183 |
+
for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
|
184 |
+
train_data = random.choice(train_dataset)
|
185 |
+
save_path = '%s/%s/train_eval_epoch%d_%s.obj' % (
|
186 |
+
opt.results_path, opt.name, epoch, train_data['name'])
|
187 |
+
gen_mesh_color(opt, netG, netC, cuda, train_data, save_path)
|
188 |
+
train_dataset.is_train = True
|
189 |
+
|
190 |
+
if __name__ == '__main__':
|
191 |
+
train_color(opt)
|
PIFu/apps/train_shape.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
|
4 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
5 |
+
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
6 |
+
|
7 |
+
import time
|
8 |
+
import json
|
9 |
+
import numpy as np
|
10 |
+
import cv2
|
11 |
+
import random
|
12 |
+
import torch
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from lib.options import BaseOptions
|
17 |
+
from lib.mesh_util import *
|
18 |
+
from lib.sample_util import *
|
19 |
+
from lib.train_util import *
|
20 |
+
from lib.data import *
|
21 |
+
from lib.model import *
|
22 |
+
from lib.geometry import index
|
23 |
+
|
24 |
+
# get options
|
25 |
+
opt = BaseOptions().parse()
|
26 |
+
|
27 |
+
def train(opt):
|
28 |
+
# set cuda
|
29 |
+
cuda = torch.device('cuda:%d' % opt.gpu_id)
|
30 |
+
|
31 |
+
train_dataset = TrainDataset(opt, phase='train')
|
32 |
+
test_dataset = TrainDataset(opt, phase='test')
|
33 |
+
|
34 |
+
projection_mode = train_dataset.projection_mode
|
35 |
+
|
36 |
+
# create data loader
|
37 |
+
train_data_loader = DataLoader(train_dataset,
|
38 |
+
batch_size=opt.batch_size, shuffle=not opt.serial_batches,
|
39 |
+
num_workers=opt.num_threads, pin_memory=opt.pin_memory)
|
40 |
+
|
41 |
+
print('train data size: ', len(train_data_loader))
|
42 |
+
|
43 |
+
# NOTE: batch size should be 1 and use all the points for evaluation
|
44 |
+
test_data_loader = DataLoader(test_dataset,
|
45 |
+
batch_size=1, shuffle=False,
|
46 |
+
num_workers=opt.num_threads, pin_memory=opt.pin_memory)
|
47 |
+
print('test data size: ', len(test_data_loader))
|
48 |
+
|
49 |
+
# create net
|
50 |
+
netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
|
51 |
+
optimizerG = torch.optim.RMSprop(netG.parameters(), lr=opt.learning_rate, momentum=0, weight_decay=0)
|
52 |
+
lr = opt.learning_rate
|
53 |
+
print('Using Network: ', netG.name)
|
54 |
+
|
55 |
+
def set_train():
|
56 |
+
netG.train()
|
57 |
+
|
58 |
+
def set_eval():
|
59 |
+
netG.eval()
|
60 |
+
|
61 |
+
# load checkpoints
|
62 |
+
if opt.load_netG_checkpoint_path is not None:
|
63 |
+
print('loading for net G ...', opt.load_netG_checkpoint_path)
|
64 |
+
netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))
|
65 |
+
|
66 |
+
if opt.continue_train:
|
67 |
+
if opt.resume_epoch < 0:
|
68 |
+
model_path = '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name)
|
69 |
+
else:
|
70 |
+
model_path = '%s/%s/netG_epoch_%d' % (opt.checkpoints_path, opt.name, opt.resume_epoch)
|
71 |
+
print('Resuming from ', model_path)
|
72 |
+
netG.load_state_dict(torch.load(model_path, map_location=cuda))
|
73 |
+
|
74 |
+
os.makedirs(opt.checkpoints_path, exist_ok=True)
|
75 |
+
os.makedirs(opt.results_path, exist_ok=True)
|
76 |
+
os.makedirs('%s/%s' % (opt.checkpoints_path, opt.name), exist_ok=True)
|
77 |
+
os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)
|
78 |
+
|
79 |
+
opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
|
80 |
+
with open(opt_log, 'w') as outfile:
|
81 |
+
outfile.write(json.dumps(vars(opt), indent=2))
|
82 |
+
|
83 |
+
# training
|
84 |
+
start_epoch = 0 if not opt.continue_train else max(opt.resume_epoch,0)
|
85 |
+
for epoch in range(start_epoch, opt.num_epoch):
|
86 |
+
epoch_start_time = time.time()
|
87 |
+
|
88 |
+
set_train()
|
89 |
+
iter_data_time = time.time()
|
90 |
+
for train_idx, train_data in enumerate(train_data_loader):
|
91 |
+
iter_start_time = time.time()
|
92 |
+
|
93 |
+
# retrieve the data
|
94 |
+
image_tensor = train_data['img'].to(device=cuda)
|
95 |
+
calib_tensor = train_data['calib'].to(device=cuda)
|
96 |
+
sample_tensor = train_data['samples'].to(device=cuda)
|
97 |
+
|
98 |
+
image_tensor, calib_tensor = reshape_multiview_tensors(image_tensor, calib_tensor)
|
99 |
+
|
100 |
+
if opt.num_views > 1:
|
101 |
+
sample_tensor = reshape_sample_tensor(sample_tensor, opt.num_views)
|
102 |
+
|
103 |
+
label_tensor = train_data['labels'].to(device=cuda)
|
104 |
+
|
105 |
+
res, error = netG.forward(image_tensor, sample_tensor, calib_tensor, labels=label_tensor)
|
106 |
+
|
107 |
+
optimizerG.zero_grad()
|
108 |
+
error.backward()
|
109 |
+
optimizerG.step()
|
110 |
+
|
111 |
+
iter_net_time = time.time()
|
112 |
+
eta = ((iter_net_time - epoch_start_time) / (train_idx + 1)) * len(train_data_loader) - (
|
113 |
+
iter_net_time - epoch_start_time)
|
114 |
+
|
115 |
+
if train_idx % opt.freq_plot == 0:
|
116 |
+
print(
|
117 |
+
'Name: {0} | Epoch: {1} | {2}/{3} | Err: {4:.06f} | LR: {5:.06f} | Sigma: {6:.02f} | dataT: {7:.05f} | netT: {8:.05f} | ETA: {9:02d}:{10:02d}'.format(
|
118 |
+
opt.name, epoch, train_idx, len(train_data_loader), error.item(), lr, opt.sigma,
|
119 |
+
iter_start_time - iter_data_time,
|
120 |
+
iter_net_time - iter_start_time, int(eta // 60),
|
121 |
+
int(eta - 60 * (eta // 60))))
|
122 |
+
|
123 |
+
if train_idx % opt.freq_save == 0 and train_idx != 0:
|
124 |
+
torch.save(netG.state_dict(), '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name))
|
125 |
+
torch.save(netG.state_dict(), '%s/%s/netG_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))
|
126 |
+
|
127 |
+
if train_idx % opt.freq_save_ply == 0:
|
128 |
+
save_path = '%s/%s/pred.ply' % (opt.results_path, opt.name)
|
129 |
+
r = res[0].cpu()
|
130 |
+
points = sample_tensor[0].transpose(0, 1).cpu()
|
131 |
+
save_samples_truncted_prob(save_path, points.detach().numpy(), r.detach().numpy())
|
132 |
+
|
133 |
+
iter_data_time = time.time()
|
134 |
+
|
135 |
+
# update learning rate
|
136 |
+
lr = adjust_learning_rate(optimizerG, epoch, lr, opt.schedule, opt.gamma)
|
137 |
+
|
138 |
+
#### test
|
139 |
+
with torch.no_grad():
|
140 |
+
set_eval()
|
141 |
+
|
142 |
+
if not opt.no_num_eval:
|
143 |
+
test_losses = {}
|
144 |
+
print('calc error (test) ...')
|
145 |
+
test_errors = calc_error(opt, netG, cuda, test_dataset, 100)
|
146 |
+
print('eval test MSE: {0:06f} IOU: {1:06f} prec: {2:06f} recall: {3:06f}'.format(*test_errors))
|
147 |
+
MSE, IOU, prec, recall = test_errors
|
148 |
+
test_losses['MSE(test)'] = MSE
|
149 |
+
test_losses['IOU(test)'] = IOU
|
150 |
+
test_losses['prec(test)'] = prec
|
151 |
+
test_losses['recall(test)'] = recall
|
152 |
+
|
153 |
+
print('calc error (train) ...')
|
154 |
+
train_dataset.is_train = False
|
155 |
+
train_errors = calc_error(opt, netG, cuda, train_dataset, 100)
|
156 |
+
train_dataset.is_train = True
|
157 |
+
print('eval train MSE: {0:06f} IOU: {1:06f} prec: {2:06f} recall: {3:06f}'.format(*train_errors))
|
158 |
+
MSE, IOU, prec, recall = train_errors
|
159 |
+
test_losses['MSE(train)'] = MSE
|
160 |
+
test_losses['IOU(train)'] = IOU
|
161 |
+
test_losses['prec(train)'] = prec
|
162 |
+
test_losses['recall(train)'] = recall
|
163 |
+
|
164 |
+
if not opt.no_gen_mesh:
|
165 |
+
print('generate mesh (test) ...')
|
166 |
+
for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
|
167 |
+
test_data = random.choice(test_dataset)
|
168 |
+
save_path = '%s/%s/test_eval_epoch%d_%s.obj' % (
|
169 |
+
opt.results_path, opt.name, epoch, test_data['name'])
|
170 |
+
gen_mesh(opt, netG, cuda, test_data, save_path)
|
171 |
+
|
172 |
+
print('generate mesh (train) ...')
|
173 |
+
train_dataset.is_train = False
|
174 |
+
for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
|
175 |
+
train_data = random.choice(train_dataset)
|
176 |
+
save_path = '%s/%s/train_eval_epoch%d_%s.obj' % (
|
177 |
+
opt.results_path, opt.name, epoch, train_data['name'])
|
178 |
+
gen_mesh(opt, netG, cuda, train_data, save_path)
|
179 |
+
train_dataset.is_train = True
|
180 |
+
|
181 |
+
|
182 |
+
if __name__ == '__main__':
|
183 |
+
train(opt)
|
PIFu/env_sh.npy
ADDED
Binary file (52 kB). View file
|
|
PIFu/environment.yml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: PIFu
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- opencv
|
7 |
+
- pytorch
|
8 |
+
- json
|
9 |
+
- pyexr
|
10 |
+
- cv2
|
11 |
+
- PIL
|
12 |
+
- skimage
|
13 |
+
- tqdm
|
14 |
+
- pyembree
|
15 |
+
- shapely
|
16 |
+
- rtree
|
17 |
+
- xxhash
|
18 |
+
- trimesh
|
19 |
+
- PyOpenGL
|
PIFu/lib/__init__.py
ADDED
File without changes
|
PIFu/lib/colab_util.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from skimage.io import imread
|
5 |
+
import numpy as np
|
6 |
+
import cv2
|
7 |
+
from tqdm import tqdm_notebook as tqdm
|
8 |
+
import base64
|
9 |
+
from IPython.display import HTML
|
10 |
+
|
11 |
+
# Util function for loading meshes
|
12 |
+
from pytorch3d.io import load_objs_as_meshes
|
13 |
+
|
14 |
+
from IPython.display import HTML
|
15 |
+
from base64 import b64encode
|
16 |
+
|
17 |
+
# Data structures and functions for rendering
|
18 |
+
from pytorch3d.structures import Meshes
|
19 |
+
from pytorch3d.renderer import (
|
20 |
+
look_at_view_transform,
|
21 |
+
OpenGLOrthographicCameras,
|
22 |
+
PointLights,
|
23 |
+
DirectionalLights,
|
24 |
+
Materials,
|
25 |
+
RasterizationSettings,
|
26 |
+
MeshRenderer,
|
27 |
+
MeshRasterizer,
|
28 |
+
SoftPhongShader,
|
29 |
+
HardPhongShader,
|
30 |
+
TexturesVertex
|
31 |
+
)
|
32 |
+
|
33 |
+
def set_renderer():
|
34 |
+
# Setup
|
35 |
+
device = torch.device("cuda:0")
|
36 |
+
torch.cuda.set_device(device)
|
37 |
+
|
38 |
+
# Initialize an OpenGL perspective camera.
|
39 |
+
R, T = look_at_view_transform(2.0, 0, 180)
|
40 |
+
cameras = OpenGLOrthographicCameras(device=device, R=R, T=T)
|
41 |
+
|
42 |
+
raster_settings = RasterizationSettings(
|
43 |
+
image_size=512,
|
44 |
+
blur_radius=0.0,
|
45 |
+
faces_per_pixel=1,
|
46 |
+
bin_size = None,
|
47 |
+
max_faces_per_bin = None
|
48 |
+
)
|
49 |
+
|
50 |
+
lights = PointLights(device=device, location=((2.0, 2.0, 2.0),))
|
51 |
+
|
52 |
+
renderer = MeshRenderer(
|
53 |
+
rasterizer=MeshRasterizer(
|
54 |
+
cameras=cameras,
|
55 |
+
raster_settings=raster_settings
|
56 |
+
),
|
57 |
+
shader=HardPhongShader(
|
58 |
+
device=device,
|
59 |
+
cameras=cameras,
|
60 |
+
lights=lights
|
61 |
+
)
|
62 |
+
)
|
63 |
+
return renderer
|
64 |
+
|
65 |
+
def get_verts_rgb_colors(obj_path):
|
66 |
+
rgb_colors = []
|
67 |
+
|
68 |
+
f = open(obj_path)
|
69 |
+
lines = f.readlines()
|
70 |
+
for line in lines:
|
71 |
+
ls = line.split(' ')
|
72 |
+
if len(ls) == 7:
|
73 |
+
rgb_colors.append(ls[-3:])
|
74 |
+
|
75 |
+
return np.array(rgb_colors, dtype='float32')[None, :, :]
|
76 |
+
|
77 |
+
def generate_video_from_obj(obj_path, video_path, renderer):
|
78 |
+
# Setup
|
79 |
+
device = torch.device("cuda:0")
|
80 |
+
torch.cuda.set_device(device)
|
81 |
+
|
82 |
+
# Load obj file
|
83 |
+
verts_rgb_colors = get_verts_rgb_colors(obj_path)
|
84 |
+
verts_rgb_colors = torch.from_numpy(verts_rgb_colors).to(device)
|
85 |
+
textures = TexturesVertex(verts_features=verts_rgb_colors)
|
86 |
+
wo_textures = TexturesVertex(verts_features=torch.ones_like(verts_rgb_colors)*0.75)
|
87 |
+
|
88 |
+
# Load obj
|
89 |
+
mesh = load_objs_as_meshes([obj_path], device=device)
|
90 |
+
|
91 |
+
# Set mesh
|
92 |
+
vers = mesh._verts_list
|
93 |
+
faces = mesh._faces_list
|
94 |
+
mesh_w_tex = Meshes(vers, faces, textures)
|
95 |
+
mesh_wo_tex = Meshes(vers, faces, wo_textures)
|
96 |
+
|
97 |
+
# create VideoWriter
|
98 |
+
fourcc = cv2. VideoWriter_fourcc(*'MP4V')
|
99 |
+
out = cv2.VideoWriter(video_path, fourcc, 20.0, (1024,512))
|
100 |
+
|
101 |
+
for i in tqdm(range(90)):
|
102 |
+
R, T = look_at_view_transform(1.8, 0, i*4, device=device)
|
103 |
+
images_w_tex = renderer(mesh_w_tex, R=R, T=T)
|
104 |
+
images_w_tex = np.clip(images_w_tex[0, ..., :3].cpu().numpy(), 0.0, 1.0)[:, :, ::-1] * 255
|
105 |
+
images_wo_tex = renderer(mesh_wo_tex, R=R, T=T)
|
106 |
+
images_wo_tex = np.clip(images_wo_tex[0, ..., :3].cpu().numpy(), 0.0, 1.0)[:, :, ::-1] * 255
|
107 |
+
image = np.concatenate([images_w_tex, images_wo_tex], axis=1)
|
108 |
+
out.write(image.astype('uint8'))
|
109 |
+
out.release()
|
110 |
+
|
111 |
+
def video(path):
|
112 |
+
mp4 = open(path,'rb').read()
|
113 |
+
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
|
114 |
+
return HTML('<video width=500 controls loop> <source src="%s" type="video/mp4"></video>' % data_url)
|
PIFu/lib/data/BaseDataset.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
import random
|
3 |
+
|
4 |
+
|
5 |
+
class BaseDataset(Dataset):
|
6 |
+
'''
|
7 |
+
This is the Base Datasets.
|
8 |
+
Itself does nothing and is not runnable.
|
9 |
+
Check self.get_item function to see what it should return.
|
10 |
+
'''
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def modify_commandline_options(parser, is_train):
|
14 |
+
return parser
|
15 |
+
|
16 |
+
def __init__(self, opt, phase='train'):
|
17 |
+
self.opt = opt
|
18 |
+
self.is_train = self.phase == 'train'
|
19 |
+
self.projection_mode = 'orthogonal' # Declare projection mode here
|
20 |
+
|
21 |
+
def __len__(self):
|
22 |
+
return 0
|
23 |
+
|
24 |
+
def get_item(self, index):
|
25 |
+
# In case of a missing file or IO error, switch to a random sample instead
|
26 |
+
try:
|
27 |
+
res = {
|
28 |
+
'name': None, # name of this subject
|
29 |
+
'b_min': None, # Bounding box (x_min, y_min, z_min) of target space
|
30 |
+
'b_max': None, # Bounding box (x_max, y_max, z_max) of target space
|
31 |
+
|
32 |
+
'samples': None, # [3, N] samples
|
33 |
+
'labels': None, # [1, N] labels
|
34 |
+
|
35 |
+
'img': None, # [num_views, C, H, W] input images
|
36 |
+
'calib': None, # [num_views, 4, 4] calibration matrix
|
37 |
+
'extrinsic': None, # [num_views, 4, 4] extrinsic matrix
|
38 |
+
'mask': None, # [num_views, 1, H, W] segmentation masks
|
39 |
+
}
|
40 |
+
return res
|
41 |
+
except:
|
42 |
+
print("Requested index %s has missing files. Using a random sample instead." % index)
|
43 |
+
return self.get_item(index=random.randint(0, self.__len__() - 1))
|
44 |
+
|
45 |
+
def __getitem__(self, index):
|
46 |
+
return self.get_item(index)
|
PIFu/lib/data/EvalDataset.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import torchvision.transforms as transforms
|
6 |
+
from PIL import Image, ImageOps
|
7 |
+
import cv2
|
8 |
+
import torch
|
9 |
+
from PIL.ImageFilter import GaussianBlur
|
10 |
+
import trimesh
|
11 |
+
import cv2
|
12 |
+
|
13 |
+
|
14 |
+
class EvalDataset(Dataset):
|
15 |
+
@staticmethod
|
16 |
+
def modify_commandline_options(parser):
|
17 |
+
return parser
|
18 |
+
|
19 |
+
def __init__(self, opt, root=None):
|
20 |
+
self.opt = opt
|
21 |
+
self.projection_mode = 'orthogonal'
|
22 |
+
|
23 |
+
# Path setup
|
24 |
+
self.root = self.opt.dataroot
|
25 |
+
if root is not None:
|
26 |
+
self.root = root
|
27 |
+
self.RENDER = os.path.join(self.root, 'RENDER')
|
28 |
+
self.MASK = os.path.join(self.root, 'MASK')
|
29 |
+
self.PARAM = os.path.join(self.root, 'PARAM')
|
30 |
+
self.OBJ = os.path.join(self.root, 'GEO', 'OBJ')
|
31 |
+
|
32 |
+
self.phase = 'val'
|
33 |
+
self.load_size = self.opt.loadSize
|
34 |
+
|
35 |
+
self.num_views = self.opt.num_views
|
36 |
+
|
37 |
+
self.max_view_angle = 360
|
38 |
+
self.interval = 1
|
39 |
+
self.subjects = self.get_subjects()
|
40 |
+
|
41 |
+
# PIL to tensor
|
42 |
+
self.to_tensor = transforms.Compose([
|
43 |
+
transforms.Resize(self.load_size),
|
44 |
+
transforms.ToTensor(),
|
45 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
46 |
+
])
|
47 |
+
|
48 |
+
def get_subjects(self):
|
49 |
+
var_file = os.path.join(self.root, 'val.txt')
|
50 |
+
if os.path.exists(var_file):
|
51 |
+
var_subjects = np.loadtxt(var_file, dtype=str)
|
52 |
+
return sorted(list(var_subjects))
|
53 |
+
all_subjects = os.listdir(self.RENDER)
|
54 |
+
return sorted(list(all_subjects))
|
55 |
+
|
56 |
+
def __len__(self):
|
57 |
+
return len(self.subjects) * self.max_view_angle // self.interval
|
58 |
+
|
59 |
+
def get_render(self, subject, num_views, view_id=None, random_sample=False):
|
60 |
+
'''
|
61 |
+
Return the render data
|
62 |
+
:param subject: subject name
|
63 |
+
:param num_views: how many views to return
|
64 |
+
:param view_id: the first view_id. If None, select a random one.
|
65 |
+
:return:
|
66 |
+
'img': [num_views, C, W, H] images
|
67 |
+
'calib': [num_views, 4, 4] calibration matrix
|
68 |
+
'extrinsic': [num_views, 4, 4] extrinsic matrix
|
69 |
+
'mask': [num_views, 1, W, H] masks
|
70 |
+
'''
|
71 |
+
# For now we only have pitch = 00. Hard code it here
|
72 |
+
pitch = 0
|
73 |
+
# Select a random view_id from self.max_view_angle if not given
|
74 |
+
if view_id is None:
|
75 |
+
view_id = np.random.randint(self.max_view_angle)
|
76 |
+
# The ids are an even distribution of num_views around view_id
|
77 |
+
view_ids = [(view_id + self.max_view_angle // num_views * offset) % self.max_view_angle
|
78 |
+
for offset in range(num_views)]
|
79 |
+
if random_sample:
|
80 |
+
view_ids = np.random.choice(self.max_view_angle, num_views, replace=False)
|
81 |
+
|
82 |
+
calib_list = []
|
83 |
+
render_list = []
|
84 |
+
mask_list = []
|
85 |
+
extrinsic_list = []
|
86 |
+
|
87 |
+
for vid in view_ids:
|
88 |
+
param_path = os.path.join(self.PARAM, subject, '%d_%02d.npy' % (vid, pitch))
|
89 |
+
render_path = os.path.join(self.RENDER, subject, '%d_%02d.jpg' % (vid, pitch))
|
90 |
+
mask_path = os.path.join(self.MASK, subject, '%d_%02d.png' % (vid, pitch))
|
91 |
+
|
92 |
+
# loading calibration data
|
93 |
+
param = np.load(param_path)
|
94 |
+
# pixel unit / world unit
|
95 |
+
ortho_ratio = param.item().get('ortho_ratio')
|
96 |
+
# world unit / model unit
|
97 |
+
scale = param.item().get('scale')
|
98 |
+
# camera center world coordinate
|
99 |
+
center = param.item().get('center')
|
100 |
+
# model rotation
|
101 |
+
R = param.item().get('R')
|
102 |
+
|
103 |
+
translate = -np.matmul(R, center).reshape(3, 1)
|
104 |
+
extrinsic = np.concatenate([R, translate], axis=1)
|
105 |
+
extrinsic = np.concatenate([extrinsic, np.array([0, 0, 0, 1]).reshape(1, 4)], 0)
|
106 |
+
# Match camera space to image pixel space
|
107 |
+
scale_intrinsic = np.identity(4)
|
108 |
+
scale_intrinsic[0, 0] = scale / ortho_ratio
|
109 |
+
scale_intrinsic[1, 1] = -scale / ortho_ratio
|
110 |
+
scale_intrinsic[2, 2] = -scale / ortho_ratio
|
111 |
+
# Match image pixel space to image uv space
|
112 |
+
uv_intrinsic = np.identity(4)
|
113 |
+
uv_intrinsic[0, 0] = 1.0 / float(self.opt.loadSize // 2)
|
114 |
+
uv_intrinsic[1, 1] = 1.0 / float(self.opt.loadSize // 2)
|
115 |
+
uv_intrinsic[2, 2] = 1.0 / float(self.opt.loadSize // 2)
|
116 |
+
# Transform under image pixel space
|
117 |
+
trans_intrinsic = np.identity(4)
|
118 |
+
|
119 |
+
mask = Image.open(mask_path).convert('L')
|
120 |
+
render = Image.open(render_path).convert('RGB')
|
121 |
+
|
122 |
+
intrinsic = np.matmul(trans_intrinsic, np.matmul(uv_intrinsic, scale_intrinsic))
|
123 |
+
calib = torch.Tensor(np.matmul(intrinsic, extrinsic)).float()
|
124 |
+
extrinsic = torch.Tensor(extrinsic).float()
|
125 |
+
|
126 |
+
mask = transforms.Resize(self.load_size)(mask)
|
127 |
+
mask = transforms.ToTensor()(mask).float()
|
128 |
+
mask_list.append(mask)
|
129 |
+
|
130 |
+
render = self.to_tensor(render)
|
131 |
+
render = mask.expand_as(render) * render
|
132 |
+
|
133 |
+
render_list.append(render)
|
134 |
+
calib_list.append(calib)
|
135 |
+
extrinsic_list.append(extrinsic)
|
136 |
+
|
137 |
+
return {
|
138 |
+
'img': torch.stack(render_list, dim=0),
|
139 |
+
'calib': torch.stack(calib_list, dim=0),
|
140 |
+
'extrinsic': torch.stack(extrinsic_list, dim=0),
|
141 |
+
'mask': torch.stack(mask_list, dim=0)
|
142 |
+
}
|
143 |
+
|
144 |
+
def get_item(self, index):
|
145 |
+
# In case of a missing file or IO error, switch to a random sample instead
|
146 |
+
try:
|
147 |
+
sid = index % len(self.subjects)
|
148 |
+
vid = (index // len(self.subjects)) * self.interval
|
149 |
+
# name of the subject 'rp_xxxx_xxx'
|
150 |
+
subject = self.subjects[sid]
|
151 |
+
res = {
|
152 |
+
'name': subject,
|
153 |
+
'mesh_path': os.path.join(self.OBJ, subject + '.obj'),
|
154 |
+
'sid': sid,
|
155 |
+
'vid': vid,
|
156 |
+
}
|
157 |
+
render_data = self.get_render(subject, num_views=self.num_views, view_id=vid,
|
158 |
+
random_sample=self.opt.random_multiview)
|
159 |
+
res.update(render_data)
|
160 |
+
return res
|
161 |
+
except Exception as e:
|
162 |
+
print(e)
|
163 |
+
return self.get_item(index=random.randint(0, self.__len__() - 1))
|
164 |
+
|
165 |
+
def __getitem__(self, index):
|
166 |
+
return self.get_item(index)
|
PIFu/lib/data/TrainDataset.py
ADDED
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import torchvision.transforms as transforms
|
6 |
+
from PIL import Image, ImageOps
|
7 |
+
import cv2
|
8 |
+
import torch
|
9 |
+
from PIL.ImageFilter import GaussianBlur
|
10 |
+
import trimesh
|
11 |
+
import logging
|
12 |
+
|
13 |
+
log = logging.getLogger('trimesh')
|
14 |
+
log.setLevel(40)
|
15 |
+
|
16 |
+
def load_trimesh(root_dir):
|
17 |
+
folders = os.listdir(root_dir)
|
18 |
+
meshs = {}
|
19 |
+
for i, f in enumerate(folders):
|
20 |
+
sub_name = f
|
21 |
+
meshs[sub_name] = trimesh.load(os.path.join(root_dir, f, '%s_100k.obj' % sub_name))
|
22 |
+
|
23 |
+
return meshs
|
24 |
+
|
25 |
+
def save_samples_truncted_prob(fname, points, prob):
|
26 |
+
'''
|
27 |
+
Save the visualization of sampling to a ply file.
|
28 |
+
Red points represent positive predictions.
|
29 |
+
Green points represent negative predictions.
|
30 |
+
:param fname: File name to save
|
31 |
+
:param points: [N, 3] array of points
|
32 |
+
:param prob: [N, 1] array of predictions in the range [0~1]
|
33 |
+
:return:
|
34 |
+
'''
|
35 |
+
r = (prob > 0.5).reshape([-1, 1]) * 255
|
36 |
+
g = (prob < 0.5).reshape([-1, 1]) * 255
|
37 |
+
b = np.zeros(r.shape)
|
38 |
+
|
39 |
+
to_save = np.concatenate([points, r, g, b], axis=-1)
|
40 |
+
return np.savetxt(fname,
|
41 |
+
to_save,
|
42 |
+
fmt='%.6f %.6f %.6f %d %d %d',
|
43 |
+
comments='',
|
44 |
+
header=(
|
45 |
+
'ply\nformat ascii 1.0\nelement vertex {:d}\nproperty float x\nproperty float y\nproperty float z\nproperty uchar red\nproperty uchar green\nproperty uchar blue\nend_header').format(
|
46 |
+
points.shape[0])
|
47 |
+
)
|
48 |
+
|
49 |
+
|
50 |
+
class TrainDataset(Dataset):
|
51 |
+
@staticmethod
|
52 |
+
def modify_commandline_options(parser, is_train):
|
53 |
+
return parser
|
54 |
+
|
55 |
+
def __init__(self, opt, phase='train'):
|
56 |
+
self.opt = opt
|
57 |
+
self.projection_mode = 'orthogonal'
|
58 |
+
|
59 |
+
# Path setup
|
60 |
+
self.root = self.opt.dataroot
|
61 |
+
self.RENDER = os.path.join(self.root, 'RENDER')
|
62 |
+
self.MASK = os.path.join(self.root, 'MASK')
|
63 |
+
self.PARAM = os.path.join(self.root, 'PARAM')
|
64 |
+
self.UV_MASK = os.path.join(self.root, 'UV_MASK')
|
65 |
+
self.UV_NORMAL = os.path.join(self.root, 'UV_NORMAL')
|
66 |
+
self.UV_RENDER = os.path.join(self.root, 'UV_RENDER')
|
67 |
+
self.UV_POS = os.path.join(self.root, 'UV_POS')
|
68 |
+
self.OBJ = os.path.join(self.root, 'GEO', 'OBJ')
|
69 |
+
|
70 |
+
self.B_MIN = np.array([-128, -28, -128])
|
71 |
+
self.B_MAX = np.array([128, 228, 128])
|
72 |
+
|
73 |
+
self.is_train = (phase == 'train')
|
74 |
+
self.load_size = self.opt.loadSize
|
75 |
+
|
76 |
+
self.num_views = self.opt.num_views
|
77 |
+
|
78 |
+
self.num_sample_inout = self.opt.num_sample_inout
|
79 |
+
self.num_sample_color = self.opt.num_sample_color
|
80 |
+
|
81 |
+
self.yaw_list = list(range(0,360,1))
|
82 |
+
self.pitch_list = [0]
|
83 |
+
self.subjects = self.get_subjects()
|
84 |
+
|
85 |
+
# PIL to tensor
|
86 |
+
self.to_tensor = transforms.Compose([
|
87 |
+
transforms.Resize(self.load_size),
|
88 |
+
transforms.ToTensor(),
|
89 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
90 |
+
])
|
91 |
+
|
92 |
+
# augmentation
|
93 |
+
self.aug_trans = transforms.Compose([
|
94 |
+
transforms.ColorJitter(brightness=opt.aug_bri, contrast=opt.aug_con, saturation=opt.aug_sat,
|
95 |
+
hue=opt.aug_hue)
|
96 |
+
])
|
97 |
+
|
98 |
+
self.mesh_dic = load_trimesh(self.OBJ)
|
99 |
+
|
100 |
+
def get_subjects(self):
|
101 |
+
all_subjects = os.listdir(self.RENDER)
|
102 |
+
var_subjects = np.loadtxt(os.path.join(self.root, 'val.txt'), dtype=str)
|
103 |
+
if len(var_subjects) == 0:
|
104 |
+
return all_subjects
|
105 |
+
|
106 |
+
if self.is_train:
|
107 |
+
return sorted(list(set(all_subjects) - set(var_subjects)))
|
108 |
+
else:
|
109 |
+
return sorted(list(var_subjects))
|
110 |
+
|
111 |
+
def __len__(self):
|
112 |
+
return len(self.subjects) * len(self.yaw_list) * len(self.pitch_list)
|
113 |
+
|
114 |
+
def get_render(self, subject, num_views, yid=0, pid=0, random_sample=False):
|
115 |
+
'''
|
116 |
+
Return the render data
|
117 |
+
:param subject: subject name
|
118 |
+
:param num_views: how many views to return
|
119 |
+
:param view_id: the first view_id. If None, select a random one.
|
120 |
+
:return:
|
121 |
+
'img': [num_views, C, W, H] images
|
122 |
+
'calib': [num_views, 4, 4] calibration matrix
|
123 |
+
'extrinsic': [num_views, 4, 4] extrinsic matrix
|
124 |
+
'mask': [num_views, 1, W, H] masks
|
125 |
+
'''
|
126 |
+
pitch = self.pitch_list[pid]
|
127 |
+
|
128 |
+
# The ids are an even distribution of num_views around view_id
|
129 |
+
view_ids = [self.yaw_list[(yid + len(self.yaw_list) // num_views * offset) % len(self.yaw_list)]
|
130 |
+
for offset in range(num_views)]
|
131 |
+
if random_sample:
|
132 |
+
view_ids = np.random.choice(self.yaw_list, num_views, replace=False)
|
133 |
+
|
134 |
+
calib_list = []
|
135 |
+
render_list = []
|
136 |
+
mask_list = []
|
137 |
+
extrinsic_list = []
|
138 |
+
|
139 |
+
for vid in view_ids:
|
140 |
+
param_path = os.path.join(self.PARAM, subject, '%d_%d_%02d.npy' % (vid, pitch, 0))
|
141 |
+
render_path = os.path.join(self.RENDER, subject, '%d_%d_%02d.jpg' % (vid, pitch, 0))
|
142 |
+
mask_path = os.path.join(self.MASK, subject, '%d_%d_%02d.png' % (vid, pitch, 0))
|
143 |
+
|
144 |
+
# loading calibration data
|
145 |
+
param = np.load(param_path, allow_pickle=True)
|
146 |
+
# pixel unit / world unit
|
147 |
+
ortho_ratio = param.item().get('ortho_ratio')
|
148 |
+
# world unit / model unit
|
149 |
+
scale = param.item().get('scale')
|
150 |
+
# camera center world coordinate
|
151 |
+
center = param.item().get('center')
|
152 |
+
# model rotation
|
153 |
+
R = param.item().get('R')
|
154 |
+
|
155 |
+
translate = -np.matmul(R, center).reshape(3, 1)
|
156 |
+
extrinsic = np.concatenate([R, translate], axis=1)
|
157 |
+
extrinsic = np.concatenate([extrinsic, np.array([0, 0, 0, 1]).reshape(1, 4)], 0)
|
158 |
+
# Match camera space to image pixel space
|
159 |
+
scale_intrinsic = np.identity(4)
|
160 |
+
scale_intrinsic[0, 0] = scale / ortho_ratio
|
161 |
+
scale_intrinsic[1, 1] = -scale / ortho_ratio
|
162 |
+
scale_intrinsic[2, 2] = scale / ortho_ratio
|
163 |
+
# Match image pixel space to image uv space
|
164 |
+
uv_intrinsic = np.identity(4)
|
165 |
+
uv_intrinsic[0, 0] = 1.0 / float(self.opt.loadSize // 2)
|
166 |
+
uv_intrinsic[1, 1] = 1.0 / float(self.opt.loadSize // 2)
|
167 |
+
uv_intrinsic[2, 2] = 1.0 / float(self.opt.loadSize // 2)
|
168 |
+
# Transform under image pixel space
|
169 |
+
trans_intrinsic = np.identity(4)
|
170 |
+
|
171 |
+
mask = Image.open(mask_path).convert('L')
|
172 |
+
render = Image.open(render_path).convert('RGB')
|
173 |
+
|
174 |
+
if self.is_train:
|
175 |
+
# Pad images
|
176 |
+
pad_size = int(0.1 * self.load_size)
|
177 |
+
render = ImageOps.expand(render, pad_size, fill=0)
|
178 |
+
mask = ImageOps.expand(mask, pad_size, fill=0)
|
179 |
+
|
180 |
+
w, h = render.size
|
181 |
+
th, tw = self.load_size, self.load_size
|
182 |
+
|
183 |
+
# random flip
|
184 |
+
if self.opt.random_flip and np.random.rand() > 0.5:
|
185 |
+
scale_intrinsic[0, 0] *= -1
|
186 |
+
render = transforms.RandomHorizontalFlip(p=1.0)(render)
|
187 |
+
mask = transforms.RandomHorizontalFlip(p=1.0)(mask)
|
188 |
+
|
189 |
+
# random scale
|
190 |
+
if self.opt.random_scale:
|
191 |
+
rand_scale = random.uniform(0.9, 1.1)
|
192 |
+
w = int(rand_scale * w)
|
193 |
+
h = int(rand_scale * h)
|
194 |
+
render = render.resize((w, h), Image.BILINEAR)
|
195 |
+
mask = mask.resize((w, h), Image.NEAREST)
|
196 |
+
scale_intrinsic *= rand_scale
|
197 |
+
scale_intrinsic[3, 3] = 1
|
198 |
+
|
199 |
+
# random translate in the pixel space
|
200 |
+
if self.opt.random_trans:
|
201 |
+
dx = random.randint(-int(round((w - tw) / 10.)),
|
202 |
+
int(round((w - tw) / 10.)))
|
203 |
+
dy = random.randint(-int(round((h - th) / 10.)),
|
204 |
+
int(round((h - th) / 10.)))
|
205 |
+
else:
|
206 |
+
dx = 0
|
207 |
+
dy = 0
|
208 |
+
|
209 |
+
trans_intrinsic[0, 3] = -dx / float(self.opt.loadSize // 2)
|
210 |
+
trans_intrinsic[1, 3] = -dy / float(self.opt.loadSize // 2)
|
211 |
+
|
212 |
+
x1 = int(round((w - tw) / 2.)) + dx
|
213 |
+
y1 = int(round((h - th) / 2.)) + dy
|
214 |
+
|
215 |
+
render = render.crop((x1, y1, x1 + tw, y1 + th))
|
216 |
+
mask = mask.crop((x1, y1, x1 + tw, y1 + th))
|
217 |
+
|
218 |
+
render = self.aug_trans(render)
|
219 |
+
|
220 |
+
# random blur
|
221 |
+
if self.opt.aug_blur > 0.00001:
|
222 |
+
blur = GaussianBlur(np.random.uniform(0, self.opt.aug_blur))
|
223 |
+
render = render.filter(blur)
|
224 |
+
|
225 |
+
intrinsic = np.matmul(trans_intrinsic, np.matmul(uv_intrinsic, scale_intrinsic))
|
226 |
+
calib = torch.Tensor(np.matmul(intrinsic, extrinsic)).float()
|
227 |
+
extrinsic = torch.Tensor(extrinsic).float()
|
228 |
+
|
229 |
+
mask = transforms.Resize(self.load_size)(mask)
|
230 |
+
mask = transforms.ToTensor()(mask).float()
|
231 |
+
mask_list.append(mask)
|
232 |
+
|
233 |
+
render = self.to_tensor(render)
|
234 |
+
render = mask.expand_as(render) * render
|
235 |
+
|
236 |
+
render_list.append(render)
|
237 |
+
calib_list.append(calib)
|
238 |
+
extrinsic_list.append(extrinsic)
|
239 |
+
|
240 |
+
return {
|
241 |
+
'img': torch.stack(render_list, dim=0),
|
242 |
+
'calib': torch.stack(calib_list, dim=0),
|
243 |
+
'extrinsic': torch.stack(extrinsic_list, dim=0),
|
244 |
+
'mask': torch.stack(mask_list, dim=0)
|
245 |
+
}
|
246 |
+
|
247 |
+
def select_sampling_method(self, subject):
|
248 |
+
if not self.is_train:
|
249 |
+
random.seed(1991)
|
250 |
+
np.random.seed(1991)
|
251 |
+
torch.manual_seed(1991)
|
252 |
+
mesh = self.mesh_dic[subject]
|
253 |
+
surface_points, _ = trimesh.sample.sample_surface(mesh, 4 * self.num_sample_inout)
|
254 |
+
sample_points = surface_points + np.random.normal(scale=self.opt.sigma, size=surface_points.shape)
|
255 |
+
|
256 |
+
# add random points within image space
|
257 |
+
length = self.B_MAX - self.B_MIN
|
258 |
+
random_points = np.random.rand(self.num_sample_inout // 4, 3) * length + self.B_MIN
|
259 |
+
sample_points = np.concatenate([sample_points, random_points], 0)
|
260 |
+
np.random.shuffle(sample_points)
|
261 |
+
|
262 |
+
inside = mesh.contains(sample_points)
|
263 |
+
inside_points = sample_points[inside]
|
264 |
+
outside_points = sample_points[np.logical_not(inside)]
|
265 |
+
|
266 |
+
nin = inside_points.shape[0]
|
267 |
+
inside_points = inside_points[
|
268 |
+
:self.num_sample_inout // 2] if nin > self.num_sample_inout // 2 else inside_points
|
269 |
+
outside_points = outside_points[
|
270 |
+
:self.num_sample_inout // 2] if nin > self.num_sample_inout // 2 else outside_points[
|
271 |
+
:(self.num_sample_inout - nin)]
|
272 |
+
|
273 |
+
samples = np.concatenate([inside_points, outside_points], 0).T
|
274 |
+
labels = np.concatenate([np.ones((1, inside_points.shape[0])), np.zeros((1, outside_points.shape[0]))], 1)
|
275 |
+
|
276 |
+
# save_samples_truncted_prob('out.ply', samples.T, labels.T)
|
277 |
+
# exit()
|
278 |
+
|
279 |
+
samples = torch.Tensor(samples).float()
|
280 |
+
labels = torch.Tensor(labels).float()
|
281 |
+
|
282 |
+
del mesh
|
283 |
+
|
284 |
+
return {
|
285 |
+
'samples': samples,
|
286 |
+
'labels': labels
|
287 |
+
}
|
288 |
+
|
289 |
+
|
290 |
+
def get_color_sampling(self, subject, yid, pid=0):
|
291 |
+
yaw = self.yaw_list[yid]
|
292 |
+
pitch = self.pitch_list[pid]
|
293 |
+
uv_render_path = os.path.join(self.UV_RENDER, subject, '%d_%d_%02d.jpg' % (yaw, pitch, 0))
|
294 |
+
uv_mask_path = os.path.join(self.UV_MASK, subject, '%02d.png' % (0))
|
295 |
+
uv_pos_path = os.path.join(self.UV_POS, subject, '%02d.exr' % (0))
|
296 |
+
uv_normal_path = os.path.join(self.UV_NORMAL, subject, '%02d.png' % (0))
|
297 |
+
|
298 |
+
# Segmentation mask for the uv render.
|
299 |
+
# [H, W] bool
|
300 |
+
uv_mask = cv2.imread(uv_mask_path)
|
301 |
+
uv_mask = uv_mask[:, :, 0] != 0
|
302 |
+
# UV render. each pixel is the color of the point.
|
303 |
+
# [H, W, 3] 0 ~ 1 float
|
304 |
+
uv_render = cv2.imread(uv_render_path)
|
305 |
+
uv_render = cv2.cvtColor(uv_render, cv2.COLOR_BGR2RGB) / 255.0
|
306 |
+
|
307 |
+
# Normal render. each pixel is the surface normal of the point.
|
308 |
+
# [H, W, 3] -1 ~ 1 float
|
309 |
+
uv_normal = cv2.imread(uv_normal_path)
|
310 |
+
uv_normal = cv2.cvtColor(uv_normal, cv2.COLOR_BGR2RGB) / 255.0
|
311 |
+
uv_normal = 2.0 * uv_normal - 1.0
|
312 |
+
# Position render. each pixel is the xyz coordinates of the point
|
313 |
+
uv_pos = cv2.imread(uv_pos_path, 2 | 4)[:, :, ::-1]
|
314 |
+
|
315 |
+
### In these few lines we flattern the masks, positions, and normals
|
316 |
+
uv_mask = uv_mask.reshape((-1))
|
317 |
+
uv_pos = uv_pos.reshape((-1, 3))
|
318 |
+
uv_render = uv_render.reshape((-1, 3))
|
319 |
+
uv_normal = uv_normal.reshape((-1, 3))
|
320 |
+
|
321 |
+
surface_points = uv_pos[uv_mask]
|
322 |
+
surface_colors = uv_render[uv_mask]
|
323 |
+
surface_normal = uv_normal[uv_mask]
|
324 |
+
|
325 |
+
if self.num_sample_color:
|
326 |
+
sample_list = random.sample(range(0, surface_points.shape[0] - 1), self.num_sample_color)
|
327 |
+
surface_points = surface_points[sample_list].T
|
328 |
+
surface_colors = surface_colors[sample_list].T
|
329 |
+
surface_normal = surface_normal[sample_list].T
|
330 |
+
|
331 |
+
# Samples are around the true surface with an offset
|
332 |
+
normal = torch.Tensor(surface_normal).float()
|
333 |
+
samples = torch.Tensor(surface_points).float() \
|
334 |
+
+ torch.normal(mean=torch.zeros((1, normal.size(1))), std=self.opt.sigma).expand_as(normal) * normal
|
335 |
+
|
336 |
+
# Normalized to [-1, 1]
|
337 |
+
rgbs_color = 2.0 * torch.Tensor(surface_colors).float() - 1.0
|
338 |
+
|
339 |
+
return {
|
340 |
+
'color_samples': samples,
|
341 |
+
'rgbs': rgbs_color
|
342 |
+
}
|
343 |
+
|
344 |
+
def get_item(self, index):
|
345 |
+
# In case of a missing file or IO error, switch to a random sample instead
|
346 |
+
# try:
|
347 |
+
sid = index % len(self.subjects)
|
348 |
+
tmp = index // len(self.subjects)
|
349 |
+
yid = tmp % len(self.yaw_list)
|
350 |
+
pid = tmp // len(self.yaw_list)
|
351 |
+
|
352 |
+
# name of the subject 'rp_xxxx_xxx'
|
353 |
+
subject = self.subjects[sid]
|
354 |
+
res = {
|
355 |
+
'name': subject,
|
356 |
+
'mesh_path': os.path.join(self.OBJ, subject + '.obj'),
|
357 |
+
'sid': sid,
|
358 |
+
'yid': yid,
|
359 |
+
'pid': pid,
|
360 |
+
'b_min': self.B_MIN,
|
361 |
+
'b_max': self.B_MAX,
|
362 |
+
}
|
363 |
+
render_data = self.get_render(subject, num_views=self.num_views, yid=yid, pid=pid,
|
364 |
+
random_sample=self.opt.random_multiview)
|
365 |
+
res.update(render_data)
|
366 |
+
|
367 |
+
if self.opt.num_sample_inout:
|
368 |
+
sample_data = self.select_sampling_method(subject)
|
369 |
+
res.update(sample_data)
|
370 |
+
|
371 |
+
# img = np.uint8((np.transpose(render_data['img'][0].numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0)
|
372 |
+
# rot = render_data['calib'][0,:3, :3]
|
373 |
+
# trans = render_data['calib'][0,:3, 3:4]
|
374 |
+
# pts = torch.addmm(trans, rot, sample_data['samples'][:, sample_data['labels'][0] > 0.5]) # [3, N]
|
375 |
+
# pts = 0.5 * (pts.numpy().T + 1.0) * render_data['img'].size(2)
|
376 |
+
# for p in pts:
|
377 |
+
# img = cv2.circle(img, (p[0], p[1]), 2, (0,255,0), -1)
|
378 |
+
# cv2.imshow('test', img)
|
379 |
+
# cv2.waitKey(1)
|
380 |
+
|
381 |
+
if self.num_sample_color:
|
382 |
+
color_data = self.get_color_sampling(subject, yid=yid, pid=pid)
|
383 |
+
res.update(color_data)
|
384 |
+
return res
|
385 |
+
# except Exception as e:
|
386 |
+
# print(e)
|
387 |
+
# return self.get_item(index=random.randint(0, self.__len__() - 1))
|
388 |
+
|
389 |
+
def __getitem__(self, index):
|
390 |
+
return self.get_item(index)
|
PIFu/lib/data/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .EvalDataset import EvalDataset
|
2 |
+
from .TrainDataset import TrainDataset
|
PIFu/lib/ext_transform.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from skimage.filters import gaussian
|
5 |
+
import torch
|
6 |
+
from PIL import Image, ImageFilter
|
7 |
+
|
8 |
+
|
9 |
+
class RandomVerticalFlip(object):
|
10 |
+
def __call__(self, img):
|
11 |
+
if random.random() < 0.5:
|
12 |
+
return img.transpose(Image.FLIP_TOP_BOTTOM)
|
13 |
+
return img
|
14 |
+
|
15 |
+
|
16 |
+
class DeNormalize(object):
|
17 |
+
def __init__(self, mean, std):
|
18 |
+
self.mean = mean
|
19 |
+
self.std = std
|
20 |
+
|
21 |
+
def __call__(self, tensor):
|
22 |
+
for t, m, s in zip(tensor, self.mean, self.std):
|
23 |
+
t.mul_(s).add_(m)
|
24 |
+
return tensor
|
25 |
+
|
26 |
+
|
27 |
+
class MaskToTensor(object):
|
28 |
+
def __call__(self, img):
|
29 |
+
return torch.from_numpy(np.array(img, dtype=np.int32)).long()
|
30 |
+
|
31 |
+
|
32 |
+
class FreeScale(object):
|
33 |
+
def __init__(self, size, interpolation=Image.BILINEAR):
|
34 |
+
self.size = tuple(reversed(size)) # size: (h, w)
|
35 |
+
self.interpolation = interpolation
|
36 |
+
|
37 |
+
def __call__(self, img):
|
38 |
+
return img.resize(self.size, self.interpolation)
|
39 |
+
|
40 |
+
|
41 |
+
class FlipChannels(object):
|
42 |
+
def __call__(self, img):
|
43 |
+
img = np.array(img)[:, :, ::-1]
|
44 |
+
return Image.fromarray(img.astype(np.uint8))
|
45 |
+
|
46 |
+
|
47 |
+
class RandomGaussianBlur(object):
|
48 |
+
def __call__(self, img):
|
49 |
+
sigma = 0.15 + random.random() * 1.15
|
50 |
+
blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True)
|
51 |
+
blurred_img *= 255
|
52 |
+
return Image.fromarray(blurred_img.astype(np.uint8))
|
53 |
+
|
54 |
+
# Lighting data augmentation take from here - https://github.com/eladhoffer/convNet.pytorch/blob/master/preprocess.py
|
55 |
+
|
56 |
+
|
57 |
+
class Lighting(object):
|
58 |
+
"""Lighting noise(AlexNet - style PCA - based noise)"""
|
59 |
+
|
60 |
+
def __init__(self, alphastd,
|
61 |
+
eigval=(0.2175, 0.0188, 0.0045),
|
62 |
+
eigvec=((-0.5675, 0.7192, 0.4009),
|
63 |
+
(-0.5808, -0.0045, -0.8140),
|
64 |
+
(-0.5836, -0.6948, 0.4203))):
|
65 |
+
self.alphastd = alphastd
|
66 |
+
self.eigval = torch.Tensor(eigval)
|
67 |
+
self.eigvec = torch.Tensor(eigvec)
|
68 |
+
|
69 |
+
def __call__(self, img):
|
70 |
+
if self.alphastd == 0:
|
71 |
+
return img
|
72 |
+
|
73 |
+
alpha = img.new().resize_(3).normal_(0, self.alphastd)
|
74 |
+
rgb = self.eigvec.type_as(img).clone()\
|
75 |
+
.mul(alpha.view(1, 3).expand(3, 3))\
|
76 |
+
.mul(self.eigval.view(1, 3).expand(3, 3))\
|
77 |
+
.sum(1).squeeze()
|
78 |
+
return img.add(rgb.view(3, 1, 1).expand_as(img))
|
PIFu/lib/geometry.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def index(feat, uv):
|
5 |
+
'''
|
6 |
+
|
7 |
+
:param feat: [B, C, H, W] image features
|
8 |
+
:param uv: [B, 2, N] uv coordinates in the image plane, range [-1, 1]
|
9 |
+
:return: [B, C, N] image features at the uv coordinates
|
10 |
+
'''
|
11 |
+
uv = uv.transpose(1, 2) # [B, N, 2]
|
12 |
+
uv = uv.unsqueeze(2) # [B, N, 1, 2]
|
13 |
+
# NOTE: for newer PyTorch, it seems that training results are degraded due to implementation diff in F.grid_sample
|
14 |
+
# for old versions, simply remove the aligned_corners argument.
|
15 |
+
samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True) # [B, C, N, 1]
|
16 |
+
return samples[:, :, :, 0] # [B, C, N]
|
17 |
+
|
18 |
+
|
19 |
+
def orthogonal(points, calibrations, transforms=None):
|
20 |
+
'''
|
21 |
+
Compute the orthogonal projections of 3D points into the image plane by given projection matrix
|
22 |
+
:param points: [B, 3, N] Tensor of 3D points
|
23 |
+
:param calibrations: [B, 4, 4] Tensor of projection matrix
|
24 |
+
:param transforms: [B, 2, 3] Tensor of image transform matrix
|
25 |
+
:return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
|
26 |
+
'''
|
27 |
+
rot = calibrations[:, :3, :3]
|
28 |
+
trans = calibrations[:, :3, 3:4]
|
29 |
+
pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
|
30 |
+
if transforms is not None:
|
31 |
+
scale = transforms[:2, :2]
|
32 |
+
shift = transforms[:2, 2:3]
|
33 |
+
pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
|
34 |
+
return pts
|
35 |
+
|
36 |
+
|
37 |
+
def perspective(points, calibrations, transforms=None):
|
38 |
+
'''
|
39 |
+
Compute the perspective projections of 3D points into the image plane by given projection matrix
|
40 |
+
:param points: [Bx3xN] Tensor of 3D points
|
41 |
+
:param calibrations: [Bx4x4] Tensor of projection matrix
|
42 |
+
:param transforms: [Bx2x3] Tensor of image transform matrix
|
43 |
+
:return: xy: [Bx2xN] Tensor of xy coordinates in the image plane
|
44 |
+
'''
|
45 |
+
rot = calibrations[:, :3, :3]
|
46 |
+
trans = calibrations[:, :3, 3:4]
|
47 |
+
homo = torch.baddbmm(trans, rot, points) # [B, 3, N]
|
48 |
+
xy = homo[:, :2, :] / homo[:, 2:3, :]
|
49 |
+
if transforms is not None:
|
50 |
+
scale = transforms[:2, :2]
|
51 |
+
shift = transforms[:2, 2:3]
|
52 |
+
xy = torch.baddbmm(shift, scale, xy)
|
53 |
+
|
54 |
+
xyz = torch.cat([xy, homo[:, 2:3, :]], 1)
|
55 |
+
return xyz
|
PIFu/lib/mesh_util.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from skimage import measure
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from .sdf import create_grid, eval_grid_octree, eval_grid
|
5 |
+
from skimage import measure
|
6 |
+
|
7 |
+
|
8 |
+
def reconstruction(net, cuda, calib_tensor,
|
9 |
+
resolution, b_min, b_max,
|
10 |
+
use_octree=False, num_samples=10000, transform=None):
|
11 |
+
'''
|
12 |
+
Reconstruct meshes from sdf predicted by the network.
|
13 |
+
:param net: a BasePixImpNet object. call image filter beforehead.
|
14 |
+
:param cuda: cuda device
|
15 |
+
:param calib_tensor: calibration tensor
|
16 |
+
:param resolution: resolution of the grid cell
|
17 |
+
:param b_min: bounding box corner [x_min, y_min, z_min]
|
18 |
+
:param b_max: bounding box corner [x_max, y_max, z_max]
|
19 |
+
:param use_octree: whether to use octree acceleration
|
20 |
+
:param num_samples: how many points to query each gpu iteration
|
21 |
+
:return: marching cubes results.
|
22 |
+
'''
|
23 |
+
# First we create a grid by resolution
|
24 |
+
# and transforming matrix for grid coordinates to real world xyz
|
25 |
+
coords, mat = create_grid(resolution, resolution, resolution,
|
26 |
+
b_min, b_max, transform=transform)
|
27 |
+
|
28 |
+
# Then we define the lambda function for cell evaluation
|
29 |
+
def eval_func(points):
|
30 |
+
points = np.expand_dims(points, axis=0)
|
31 |
+
points = np.repeat(points, net.num_views, axis=0)
|
32 |
+
samples = torch.from_numpy(points).to(device=cuda).float()
|
33 |
+
net.query(samples, calib_tensor)
|
34 |
+
pred = net.get_preds()[0][0]
|
35 |
+
return pred.detach().cpu().numpy()
|
36 |
+
|
37 |
+
# Then we evaluate the grid
|
38 |
+
if use_octree:
|
39 |
+
sdf = eval_grid_octree(coords, eval_func, num_samples=num_samples)
|
40 |
+
else:
|
41 |
+
sdf = eval_grid(coords, eval_func, num_samples=num_samples)
|
42 |
+
|
43 |
+
# Finally we do marching cubes
|
44 |
+
try:
|
45 |
+
verts, faces, normals, values = measure.marching_cubes_lewiner(sdf, 0.5)
|
46 |
+
# transform verts into world coordinate system
|
47 |
+
verts = np.matmul(mat[:3, :3], verts.T) + mat[:3, 3:4]
|
48 |
+
verts = verts.T
|
49 |
+
return verts, faces, normals, values
|
50 |
+
except:
|
51 |
+
print('error cannot marching cubes')
|
52 |
+
return -1
|
53 |
+
|
54 |
+
|
55 |
+
def save_obj_mesh(mesh_path, verts, faces):
|
56 |
+
file = open(mesh_path, 'w')
|
57 |
+
|
58 |
+
for v in verts:
|
59 |
+
file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
|
60 |
+
for f in faces:
|
61 |
+
f_plus = f + 1
|
62 |
+
file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
|
63 |
+
file.close()
|
64 |
+
|
65 |
+
|
66 |
+
def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
|
67 |
+
file = open(mesh_path, 'w')
|
68 |
+
|
69 |
+
for idx, v in enumerate(verts):
|
70 |
+
c = colors[idx]
|
71 |
+
file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' % (v[0], v[1], v[2], c[0], c[1], c[2]))
|
72 |
+
for f in faces:
|
73 |
+
f_plus = f + 1
|
74 |
+
file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
|
75 |
+
file.close()
|
76 |
+
|
77 |
+
|
78 |
+
def save_obj_mesh_with_uv(mesh_path, verts, faces, uvs):
|
79 |
+
file = open(mesh_path, 'w')
|
80 |
+
|
81 |
+
for idx, v in enumerate(verts):
|
82 |
+
vt = uvs[idx]
|
83 |
+
file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
|
84 |
+
file.write('vt %.4f %.4f\n' % (vt[0], vt[1]))
|
85 |
+
|
86 |
+
for f in faces:
|
87 |
+
f_plus = f + 1
|
88 |
+
file.write('f %d/%d %d/%d %d/%d\n' % (f_plus[0], f_plus[0],
|
89 |
+
f_plus[2], f_plus[2],
|
90 |
+
f_plus[1], f_plus[1]))
|
91 |
+
file.close()
|
PIFu/lib/model/BasePIFuNet.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from ..geometry import index, orthogonal, perspective
|
6 |
+
|
7 |
+
class BasePIFuNet(nn.Module):
|
8 |
+
def __init__(self,
|
9 |
+
projection_mode='orthogonal',
|
10 |
+
error_term=nn.MSELoss(),
|
11 |
+
):
|
12 |
+
"""
|
13 |
+
:param projection_mode:
|
14 |
+
Either orthogonal or perspective.
|
15 |
+
It will call the corresponding function for projection.
|
16 |
+
:param error_term:
|
17 |
+
nn Loss between the predicted [B, Res, N] and the label [B, Res, N]
|
18 |
+
"""
|
19 |
+
super(BasePIFuNet, self).__init__()
|
20 |
+
self.name = 'base'
|
21 |
+
|
22 |
+
self.error_term = error_term
|
23 |
+
|
24 |
+
self.index = index
|
25 |
+
self.projection = orthogonal if projection_mode == 'orthogonal' else perspective
|
26 |
+
|
27 |
+
self.preds = None
|
28 |
+
self.labels = None
|
29 |
+
|
30 |
+
def forward(self, points, images, calibs, transforms=None):
|
31 |
+
'''
|
32 |
+
:param points: [B, 3, N] world space coordinates of points
|
33 |
+
:param images: [B, C, H, W] input images
|
34 |
+
:param calibs: [B, 3, 4] calibration matrices for each image
|
35 |
+
:param transforms: Optional [B, 2, 3] image space coordinate transforms
|
36 |
+
:return: [B, Res, N] predictions for each point
|
37 |
+
'''
|
38 |
+
self.filter(images)
|
39 |
+
self.query(points, calibs, transforms)
|
40 |
+
return self.get_preds()
|
41 |
+
|
42 |
+
def filter(self, images):
|
43 |
+
'''
|
44 |
+
Filter the input images
|
45 |
+
store all intermediate features.
|
46 |
+
:param images: [B, C, H, W] input images
|
47 |
+
'''
|
48 |
+
None
|
49 |
+
|
50 |
+
def query(self, points, calibs, transforms=None, labels=None):
|
51 |
+
'''
|
52 |
+
Given 3D points, query the network predictions for each point.
|
53 |
+
Image features should be pre-computed before this call.
|
54 |
+
store all intermediate features.
|
55 |
+
query() function may behave differently during training/testing.
|
56 |
+
:param points: [B, 3, N] world space coordinates of points
|
57 |
+
:param calibs: [B, 3, 4] calibration matrices for each image
|
58 |
+
:param transforms: Optional [B, 2, 3] image space coordinate transforms
|
59 |
+
:param labels: Optional [B, Res, N] gt labeling
|
60 |
+
:return: [B, Res, N] predictions for each point
|
61 |
+
'''
|
62 |
+
None
|
63 |
+
|
64 |
+
def get_preds(self):
|
65 |
+
'''
|
66 |
+
Get the predictions from the last query
|
67 |
+
:return: [B, Res, N] network prediction for the last query
|
68 |
+
'''
|
69 |
+
return self.preds
|
70 |
+
|
71 |
+
def get_error(self):
|
72 |
+
'''
|
73 |
+
Get the network loss from the last query
|
74 |
+
:return: loss term
|
75 |
+
'''
|
76 |
+
return self.error_term(self.preds, self.labels)
|
PIFu/lib/model/ConvFilters.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision.models.resnet as resnet
|
5 |
+
import torchvision.models.vgg as vgg
|
6 |
+
|
7 |
+
|
8 |
+
class MultiConv(nn.Module):
|
9 |
+
def __init__(self, filter_channels):
|
10 |
+
super(MultiConv, self).__init__()
|
11 |
+
self.filters = []
|
12 |
+
|
13 |
+
for l in range(0, len(filter_channels) - 1):
|
14 |
+
self.filters.append(
|
15 |
+
nn.Conv2d(filter_channels[l], filter_channels[l + 1], kernel_size=4, stride=2))
|
16 |
+
self.add_module("conv%d" % l, self.filters[l])
|
17 |
+
|
18 |
+
def forward(self, image):
|
19 |
+
'''
|
20 |
+
:param image: [BxC_inxHxW] tensor of input image
|
21 |
+
:return: list of [BxC_outxHxW] tensors of output features
|
22 |
+
'''
|
23 |
+
y = image
|
24 |
+
# y = F.relu(self.bn0(self.conv0(y)), True)
|
25 |
+
feat_pyramid = [y]
|
26 |
+
for i, f in enumerate(self.filters):
|
27 |
+
y = f(y)
|
28 |
+
if i != len(self.filters) - 1:
|
29 |
+
y = F.leaky_relu(y)
|
30 |
+
# y = F.max_pool2d(y, kernel_size=2, stride=2)
|
31 |
+
feat_pyramid.append(y)
|
32 |
+
return feat_pyramid
|
33 |
+
|
34 |
+
|
35 |
+
class Vgg16(torch.nn.Module):
|
36 |
+
def __init__(self):
|
37 |
+
super(Vgg16, self).__init__()
|
38 |
+
vgg_pretrained_features = vgg.vgg16(pretrained=True).features
|
39 |
+
self.slice1 = torch.nn.Sequential()
|
40 |
+
self.slice2 = torch.nn.Sequential()
|
41 |
+
self.slice3 = torch.nn.Sequential()
|
42 |
+
self.slice4 = torch.nn.Sequential()
|
43 |
+
self.slice5 = torch.nn.Sequential()
|
44 |
+
|
45 |
+
for x in range(4):
|
46 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
47 |
+
for x in range(4, 9):
|
48 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
49 |
+
for x in range(9, 16):
|
50 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
51 |
+
for x in range(16, 23):
|
52 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
53 |
+
for x in range(23, 30):
|
54 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
55 |
+
|
56 |
+
def forward(self, X):
|
57 |
+
h = self.slice1(X)
|
58 |
+
h_relu1_2 = h
|
59 |
+
h = self.slice2(h)
|
60 |
+
h_relu2_2 = h
|
61 |
+
h = self.slice3(h)
|
62 |
+
h_relu3_3 = h
|
63 |
+
h = self.slice4(h)
|
64 |
+
h_relu4_3 = h
|
65 |
+
h = self.slice5(h)
|
66 |
+
h_relu5_3 = h
|
67 |
+
|
68 |
+
return [h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
|
69 |
+
|
70 |
+
|
71 |
+
class ResNet(nn.Module):
|
72 |
+
def __init__(self, model='resnet18'):
|
73 |
+
super(ResNet, self).__init__()
|
74 |
+
|
75 |
+
if model == 'resnet18':
|
76 |
+
net = resnet.resnet18(pretrained=True)
|
77 |
+
elif model == 'resnet34':
|
78 |
+
net = resnet.resnet34(pretrained=True)
|
79 |
+
elif model == 'resnet50':
|
80 |
+
net = resnet.resnet50(pretrained=True)
|
81 |
+
else:
|
82 |
+
raise NameError('Unknown Fan Filter setting!')
|
83 |
+
|
84 |
+
self.conv1 = net.conv1
|
85 |
+
|
86 |
+
self.pool = net.maxpool
|
87 |
+
self.layer0 = nn.Sequential(net.conv1, net.bn1, net.relu)
|
88 |
+
self.layer1 = net.layer1
|
89 |
+
self.layer2 = net.layer2
|
90 |
+
self.layer3 = net.layer3
|
91 |
+
self.layer4 = net.layer4
|
92 |
+
|
93 |
+
def forward(self, image):
|
94 |
+
'''
|
95 |
+
:param image: [BxC_inxHxW] tensor of input image
|
96 |
+
:return: list of [BxC_outxHxW] tensors of output features
|
97 |
+
'''
|
98 |
+
|
99 |
+
y = image
|
100 |
+
feat_pyramid = []
|
101 |
+
y = self.layer0(y)
|
102 |
+
feat_pyramid.append(y)
|
103 |
+
y = self.layer1(self.pool(y))
|
104 |
+
feat_pyramid.append(y)
|
105 |
+
y = self.layer2(y)
|
106 |
+
feat_pyramid.append(y)
|
107 |
+
y = self.layer3(y)
|
108 |
+
feat_pyramid.append(y)
|
109 |
+
y = self.layer4(y)
|
110 |
+
feat_pyramid.append(y)
|
111 |
+
|
112 |
+
return feat_pyramid
|
PIFu/lib/model/ConvPIFuNet.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from .BasePIFuNet import BasePIFuNet
|
5 |
+
from .SurfaceClassifier import SurfaceClassifier
|
6 |
+
from .DepthNormalizer import DepthNormalizer
|
7 |
+
from .ConvFilters import *
|
8 |
+
from ..net_util import init_net
|
9 |
+
|
10 |
+
class ConvPIFuNet(BasePIFuNet):
|
11 |
+
'''
|
12 |
+
Conv Piximp network is the standard 3-phase network that we will use.
|
13 |
+
The image filter is a pure multi-layer convolutional network,
|
14 |
+
while during feature extraction phase all features in the pyramid at the projected location
|
15 |
+
will be aggregated.
|
16 |
+
It does the following:
|
17 |
+
1. Compute image feature pyramids and store it in self.im_feat_list
|
18 |
+
2. Calculate calibration and indexing on each of the feat, and append them together
|
19 |
+
3. Classification.
|
20 |
+
'''
|
21 |
+
|
22 |
+
def __init__(self,
|
23 |
+
opt,
|
24 |
+
projection_mode='orthogonal',
|
25 |
+
error_term=nn.MSELoss(),
|
26 |
+
):
|
27 |
+
super(ConvPIFuNet, self).__init__(
|
28 |
+
projection_mode=projection_mode,
|
29 |
+
error_term=error_term)
|
30 |
+
|
31 |
+
self.name = 'convpifu'
|
32 |
+
|
33 |
+
self.opt = opt
|
34 |
+
self.num_views = self.opt.num_views
|
35 |
+
|
36 |
+
self.image_filter = self.define_imagefilter(opt)
|
37 |
+
|
38 |
+
self.surface_classifier = SurfaceClassifier(
|
39 |
+
filter_channels=self.opt.mlp_dim,
|
40 |
+
num_views=self.opt.num_views,
|
41 |
+
no_residual=self.opt.no_residual,
|
42 |
+
last_op=nn.Sigmoid())
|
43 |
+
|
44 |
+
self.normalizer = DepthNormalizer(opt)
|
45 |
+
|
46 |
+
# This is a list of [B x Feat_i x H x W] features
|
47 |
+
self.im_feat_list = []
|
48 |
+
|
49 |
+
init_net(self)
|
50 |
+
|
51 |
+
def define_imagefilter(self, opt):
|
52 |
+
net = None
|
53 |
+
if opt.netIMF == 'multiconv':
|
54 |
+
net = MultiConv(opt.enc_dim)
|
55 |
+
elif 'resnet' in opt.netIMF:
|
56 |
+
net = ResNet(model=opt.netIMF)
|
57 |
+
elif opt.netIMF == 'vgg16':
|
58 |
+
net = Vgg16()
|
59 |
+
else:
|
60 |
+
raise NotImplementedError('model name [%s] is not recognized' % opt.imf_type)
|
61 |
+
|
62 |
+
return net
|
63 |
+
|
64 |
+
def filter(self, images):
|
65 |
+
'''
|
66 |
+
Filter the input images
|
67 |
+
store all intermediate features.
|
68 |
+
:param images: [B, C, H, W] input images
|
69 |
+
'''
|
70 |
+
self.im_feat_list = self.image_filter(images)
|
71 |
+
|
72 |
+
def query(self, points, calibs, transforms=None, labels=None):
|
73 |
+
'''
|
74 |
+
Given 3D points, query the network predictions for each point.
|
75 |
+
Image features should be pre-computed before this call.
|
76 |
+
store all intermediate features.
|
77 |
+
query() function may behave differently during training/testing.
|
78 |
+
:param points: [B, 3, N] world space coordinates of points
|
79 |
+
:param calibs: [B, 3, 4] calibration matrices for each image
|
80 |
+
:param transforms: Optional [B, 2, 3] image space coordinate transforms
|
81 |
+
:param labels: Optional [B, Res, N] gt labeling
|
82 |
+
:return: [B, Res, N] predictions for each point
|
83 |
+
'''
|
84 |
+
if labels is not None:
|
85 |
+
self.labels = labels
|
86 |
+
|
87 |
+
xyz = self.projection(points, calibs, transforms)
|
88 |
+
xy = xyz[:, :2, :]
|
89 |
+
z = xyz[:, 2:3, :]
|
90 |
+
|
91 |
+
z_feat = self.normalizer(z)
|
92 |
+
|
93 |
+
# This is a list of [B, Feat_i, N] features
|
94 |
+
point_local_feat_list = [self.index(im_feat, xy) for im_feat in self.im_feat_list]
|
95 |
+
point_local_feat_list.append(z_feat)
|
96 |
+
# [B, Feat_all, N]
|
97 |
+
point_local_feat = torch.cat(point_local_feat_list, 1)
|
98 |
+
|
99 |
+
self.preds = self.surface_classifier(point_local_feat)
|
PIFu/lib/model/DepthNormalizer.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class DepthNormalizer(nn.Module):
|
7 |
+
def __init__(self, opt):
|
8 |
+
super(DepthNormalizer, self).__init__()
|
9 |
+
self.opt = opt
|
10 |
+
|
11 |
+
def forward(self, z, calibs=None, index_feat=None):
|
12 |
+
'''
|
13 |
+
Normalize z_feature
|
14 |
+
:param z_feat: [B, 1, N] depth value for z in the image coordinate system
|
15 |
+
:return:
|
16 |
+
'''
|
17 |
+
z_feat = z * (self.opt.loadSize // 2) / self.opt.z_size
|
18 |
+
return z_feat
|
PIFu/lib/model/HGFilters.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from ..net_util import *
|
5 |
+
|
6 |
+
|
7 |
+
class HourGlass(nn.Module):
|
8 |
+
def __init__(self, num_modules, depth, num_features, norm='batch'):
|
9 |
+
super(HourGlass, self).__init__()
|
10 |
+
self.num_modules = num_modules
|
11 |
+
self.depth = depth
|
12 |
+
self.features = num_features
|
13 |
+
self.norm = norm
|
14 |
+
|
15 |
+
self._generate_network(self.depth)
|
16 |
+
|
17 |
+
def _generate_network(self, level):
|
18 |
+
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))
|
19 |
+
|
20 |
+
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))
|
21 |
+
|
22 |
+
if level > 1:
|
23 |
+
self._generate_network(level - 1)
|
24 |
+
else:
|
25 |
+
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))
|
26 |
+
|
27 |
+
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))
|
28 |
+
|
29 |
+
def _forward(self, level, inp):
|
30 |
+
# Upper branch
|
31 |
+
up1 = inp
|
32 |
+
up1 = self._modules['b1_' + str(level)](up1)
|
33 |
+
|
34 |
+
# Lower branch
|
35 |
+
low1 = F.avg_pool2d(inp, 2, stride=2)
|
36 |
+
low1 = self._modules['b2_' + str(level)](low1)
|
37 |
+
|
38 |
+
if level > 1:
|
39 |
+
low2 = self._forward(level - 1, low1)
|
40 |
+
else:
|
41 |
+
low2 = low1
|
42 |
+
low2 = self._modules['b2_plus_' + str(level)](low2)
|
43 |
+
|
44 |
+
low3 = low2
|
45 |
+
low3 = self._modules['b3_' + str(level)](low3)
|
46 |
+
|
47 |
+
# NOTE: for newer PyTorch (1.3~), it seems that training results are degraded due to implementation diff in F.grid_sample
|
48 |
+
# if the pretrained model behaves weirdly, switch with the commented line.
|
49 |
+
# NOTE: I also found that "bicubic" works better.
|
50 |
+
up2 = F.interpolate(low3, scale_factor=2, mode='bicubic', align_corners=True)
|
51 |
+
# up2 = F.interpolate(low3, scale_factor=2, mode='nearest)
|
52 |
+
|
53 |
+
return up1 + up2
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
return self._forward(self.depth, x)
|
57 |
+
|
58 |
+
|
59 |
+
class HGFilter(nn.Module):
|
60 |
+
def __init__(self, opt):
|
61 |
+
super(HGFilter, self).__init__()
|
62 |
+
self.num_modules = opt.num_stack
|
63 |
+
|
64 |
+
self.opt = opt
|
65 |
+
|
66 |
+
# Base part
|
67 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
68 |
+
|
69 |
+
if self.opt.norm == 'batch':
|
70 |
+
self.bn1 = nn.BatchNorm2d(64)
|
71 |
+
elif self.opt.norm == 'group':
|
72 |
+
self.bn1 = nn.GroupNorm(32, 64)
|
73 |
+
|
74 |
+
if self.opt.hg_down == 'conv64':
|
75 |
+
self.conv2 = ConvBlock(64, 64, self.opt.norm)
|
76 |
+
self.down_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
|
77 |
+
elif self.opt.hg_down == 'conv128':
|
78 |
+
self.conv2 = ConvBlock(64, 128, self.opt.norm)
|
79 |
+
self.down_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
|
80 |
+
elif self.opt.hg_down == 'ave_pool':
|
81 |
+
self.conv2 = ConvBlock(64, 128, self.opt.norm)
|
82 |
+
else:
|
83 |
+
raise NameError('Unknown Fan Filter setting!')
|
84 |
+
|
85 |
+
self.conv3 = ConvBlock(128, 128, self.opt.norm)
|
86 |
+
self.conv4 = ConvBlock(128, 256, self.opt.norm)
|
87 |
+
|
88 |
+
# Stacking part
|
89 |
+
for hg_module in range(self.num_modules):
|
90 |
+
self.add_module('m' + str(hg_module), HourGlass(1, opt.num_hourglass, 256, self.opt.norm))
|
91 |
+
|
92 |
+
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256, self.opt.norm))
|
93 |
+
self.add_module('conv_last' + str(hg_module),
|
94 |
+
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
95 |
+
if self.opt.norm == 'batch':
|
96 |
+
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
|
97 |
+
elif self.opt.norm == 'group':
|
98 |
+
self.add_module('bn_end' + str(hg_module), nn.GroupNorm(32, 256))
|
99 |
+
|
100 |
+
self.add_module('l' + str(hg_module), nn.Conv2d(256,
|
101 |
+
opt.hourglass_dim, kernel_size=1, stride=1, padding=0))
|
102 |
+
|
103 |
+
if hg_module < self.num_modules - 1:
|
104 |
+
self.add_module(
|
105 |
+
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
106 |
+
self.add_module('al' + str(hg_module), nn.Conv2d(opt.hourglass_dim,
|
107 |
+
256, kernel_size=1, stride=1, padding=0))
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
x = F.relu(self.bn1(self.conv1(x)), True)
|
111 |
+
tmpx = x
|
112 |
+
if self.opt.hg_down == 'ave_pool':
|
113 |
+
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
|
114 |
+
elif self.opt.hg_down in ['conv64', 'conv128']:
|
115 |
+
x = self.conv2(x)
|
116 |
+
x = self.down_conv2(x)
|
117 |
+
else:
|
118 |
+
raise NameError('Unknown Fan Filter setting!')
|
119 |
+
|
120 |
+
normx = x
|
121 |
+
|
122 |
+
x = self.conv3(x)
|
123 |
+
x = self.conv4(x)
|
124 |
+
|
125 |
+
previous = x
|
126 |
+
|
127 |
+
outputs = []
|
128 |
+
for i in range(self.num_modules):
|
129 |
+
hg = self._modules['m' + str(i)](previous)
|
130 |
+
|
131 |
+
ll = hg
|
132 |
+
ll = self._modules['top_m_' + str(i)](ll)
|
133 |
+
|
134 |
+
ll = F.relu(self._modules['bn_end' + str(i)]
|
135 |
+
(self._modules['conv_last' + str(i)](ll)), True)
|
136 |
+
|
137 |
+
# Predict heatmaps
|
138 |
+
tmp_out = self._modules['l' + str(i)](ll)
|
139 |
+
outputs.append(tmp_out)
|
140 |
+
|
141 |
+
if i < self.num_modules - 1:
|
142 |
+
ll = self._modules['bl' + str(i)](ll)
|
143 |
+
tmp_out_ = self._modules['al' + str(i)](tmp_out)
|
144 |
+
previous = previous + ll + tmp_out_
|
145 |
+
|
146 |
+
return outputs, tmpx.detach(), normx
|
PIFu/lib/model/HGPIFuNet.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from .BasePIFuNet import BasePIFuNet
|
5 |
+
from .SurfaceClassifier import SurfaceClassifier
|
6 |
+
from .DepthNormalizer import DepthNormalizer
|
7 |
+
from .HGFilters import *
|
8 |
+
from ..net_util import init_net
|
9 |
+
|
10 |
+
|
11 |
+
class HGPIFuNet(BasePIFuNet):
|
12 |
+
'''
|
13 |
+
HG PIFu network uses Hourglass stacks as the image filter.
|
14 |
+
It does the following:
|
15 |
+
1. Compute image feature stacks and store it in self.im_feat_list
|
16 |
+
self.im_feat_list[-1] is the last stack (output stack)
|
17 |
+
2. Calculate calibration
|
18 |
+
3. If training, it index on every intermediate stacks,
|
19 |
+
If testing, it index on the last stack.
|
20 |
+
4. Classification.
|
21 |
+
5. During training, error is calculated on all stacks.
|
22 |
+
'''
|
23 |
+
|
24 |
+
def __init__(self,
|
25 |
+
opt,
|
26 |
+
projection_mode='orthogonal',
|
27 |
+
error_term=nn.MSELoss(),
|
28 |
+
):
|
29 |
+
super(HGPIFuNet, self).__init__(
|
30 |
+
projection_mode=projection_mode,
|
31 |
+
error_term=error_term)
|
32 |
+
|
33 |
+
self.name = 'hgpifu'
|
34 |
+
|
35 |
+
self.opt = opt
|
36 |
+
self.num_views = self.opt.num_views
|
37 |
+
|
38 |
+
self.image_filter = HGFilter(opt)
|
39 |
+
|
40 |
+
self.surface_classifier = SurfaceClassifier(
|
41 |
+
filter_channels=self.opt.mlp_dim,
|
42 |
+
num_views=self.opt.num_views,
|
43 |
+
no_residual=self.opt.no_residual,
|
44 |
+
last_op=nn.Sigmoid())
|
45 |
+
|
46 |
+
self.normalizer = DepthNormalizer(opt)
|
47 |
+
|
48 |
+
# This is a list of [B x Feat_i x H x W] features
|
49 |
+
self.im_feat_list = []
|
50 |
+
self.tmpx = None
|
51 |
+
self.normx = None
|
52 |
+
|
53 |
+
self.intermediate_preds_list = []
|
54 |
+
|
55 |
+
init_net(self)
|
56 |
+
|
57 |
+
def filter(self, images):
|
58 |
+
'''
|
59 |
+
Filter the input images
|
60 |
+
store all intermediate features.
|
61 |
+
:param images: [B, C, H, W] input images
|
62 |
+
'''
|
63 |
+
self.im_feat_list, self.tmpx, self.normx = self.image_filter(images)
|
64 |
+
# If it is not in training, only produce the last im_feat
|
65 |
+
if not self.training:
|
66 |
+
self.im_feat_list = [self.im_feat_list[-1]]
|
67 |
+
|
68 |
+
def query(self, points, calibs, transforms=None, labels=None):
|
69 |
+
'''
|
70 |
+
Given 3D points, query the network predictions for each point.
|
71 |
+
Image features should be pre-computed before this call.
|
72 |
+
store all intermediate features.
|
73 |
+
query() function may behave differently during training/testing.
|
74 |
+
:param points: [B, 3, N] world space coordinates of points
|
75 |
+
:param calibs: [B, 3, 4] calibration matrices for each image
|
76 |
+
:param transforms: Optional [B, 2, 3] image space coordinate transforms
|
77 |
+
:param labels: Optional [B, Res, N] gt labeling
|
78 |
+
:return: [B, Res, N] predictions for each point
|
79 |
+
'''
|
80 |
+
if labels is not None:
|
81 |
+
self.labels = labels
|
82 |
+
|
83 |
+
xyz = self.projection(points, calibs, transforms)
|
84 |
+
xy = xyz[:, :2, :]
|
85 |
+
z = xyz[:, 2:3, :]
|
86 |
+
|
87 |
+
in_img = (xy[:, 0] >= -1.0) & (xy[:, 0] <= 1.0) & (xy[:, 1] >= -1.0) & (xy[:, 1] <= 1.0)
|
88 |
+
|
89 |
+
z_feat = self.normalizer(z, calibs=calibs)
|
90 |
+
|
91 |
+
if self.opt.skip_hourglass:
|
92 |
+
tmpx_local_feature = self.index(self.tmpx, xy)
|
93 |
+
|
94 |
+
self.intermediate_preds_list = []
|
95 |
+
|
96 |
+
for im_feat in self.im_feat_list:
|
97 |
+
# [B, Feat_i + z, N]
|
98 |
+
point_local_feat_list = [self.index(im_feat, xy), z_feat]
|
99 |
+
|
100 |
+
if self.opt.skip_hourglass:
|
101 |
+
point_local_feat_list.append(tmpx_local_feature)
|
102 |
+
|
103 |
+
point_local_feat = torch.cat(point_local_feat_list, 1)
|
104 |
+
|
105 |
+
# out of image plane is always set to 0
|
106 |
+
pred = in_img[:,None].float() * self.surface_classifier(point_local_feat)
|
107 |
+
self.intermediate_preds_list.append(pred)
|
108 |
+
|
109 |
+
self.preds = self.intermediate_preds_list[-1]
|
110 |
+
|
111 |
+
def get_im_feat(self):
|
112 |
+
'''
|
113 |
+
Get the image filter
|
114 |
+
:return: [B, C_feat, H, W] image feature after filtering
|
115 |
+
'''
|
116 |
+
return self.im_feat_list[-1]
|
117 |
+
|
118 |
+
def get_error(self):
|
119 |
+
'''
|
120 |
+
Hourglass has its own intermediate supervision scheme
|
121 |
+
'''
|
122 |
+
error = 0
|
123 |
+
for preds in self.intermediate_preds_list:
|
124 |
+
error += self.error_term(preds, self.labels)
|
125 |
+
error /= len(self.intermediate_preds_list)
|
126 |
+
|
127 |
+
return error
|
128 |
+
|
129 |
+
def forward(self, images, points, calibs, transforms=None, labels=None):
|
130 |
+
# Get image feature
|
131 |
+
self.filter(images)
|
132 |
+
|
133 |
+
# Phase 2: point query
|
134 |
+
self.query(points=points, calibs=calibs, transforms=transforms, labels=labels)
|
135 |
+
|
136 |
+
# get the prediction
|
137 |
+
res = self.get_preds()
|
138 |
+
|
139 |
+
# get the error
|
140 |
+
error = self.get_error()
|
141 |
+
|
142 |
+
return res, error
|
PIFu/lib/model/ResBlkPIFuNet.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from .BasePIFuNet import BasePIFuNet
|
5 |
+
import functools
|
6 |
+
from .SurfaceClassifier import SurfaceClassifier
|
7 |
+
from .DepthNormalizer import DepthNormalizer
|
8 |
+
from ..net_util import *
|
9 |
+
|
10 |
+
|
11 |
+
class ResBlkPIFuNet(BasePIFuNet):
|
12 |
+
def __init__(self, opt,
|
13 |
+
projection_mode='orthogonal'):
|
14 |
+
if opt.color_loss_type == 'l1':
|
15 |
+
error_term = nn.L1Loss()
|
16 |
+
elif opt.color_loss_type == 'mse':
|
17 |
+
error_term = nn.MSELoss()
|
18 |
+
|
19 |
+
super(ResBlkPIFuNet, self).__init__(
|
20 |
+
projection_mode=projection_mode,
|
21 |
+
error_term=error_term)
|
22 |
+
|
23 |
+
self.name = 'respifu'
|
24 |
+
self.opt = opt
|
25 |
+
|
26 |
+
norm_type = get_norm_layer(norm_type=opt.norm_color)
|
27 |
+
self.image_filter = ResnetFilter(opt, norm_layer=norm_type)
|
28 |
+
|
29 |
+
self.surface_classifier = SurfaceClassifier(
|
30 |
+
filter_channels=self.opt.mlp_dim_color,
|
31 |
+
num_views=self.opt.num_views,
|
32 |
+
no_residual=self.opt.no_residual,
|
33 |
+
last_op=nn.Tanh())
|
34 |
+
|
35 |
+
self.normalizer = DepthNormalizer(opt)
|
36 |
+
|
37 |
+
init_net(self)
|
38 |
+
|
39 |
+
def filter(self, images):
|
40 |
+
'''
|
41 |
+
Filter the input images
|
42 |
+
store all intermediate features.
|
43 |
+
:param images: [B, C, H, W] input images
|
44 |
+
'''
|
45 |
+
self.im_feat = self.image_filter(images)
|
46 |
+
|
47 |
+
def attach(self, im_feat):
|
48 |
+
self.im_feat = torch.cat([im_feat, self.im_feat], 1)
|
49 |
+
|
50 |
+
def query(self, points, calibs, transforms=None, labels=None):
|
51 |
+
'''
|
52 |
+
Given 3D points, query the network predictions for each point.
|
53 |
+
Image features should be pre-computed before this call.
|
54 |
+
store all intermediate features.
|
55 |
+
query() function may behave differently during training/testing.
|
56 |
+
:param points: [B, 3, N] world space coordinates of points
|
57 |
+
:param calibs: [B, 3, 4] calibration matrices for each image
|
58 |
+
:param transforms: Optional [B, 2, 3] image space coordinate transforms
|
59 |
+
:param labels: Optional [B, Res, N] gt labeling
|
60 |
+
:return: [B, Res, N] predictions for each point
|
61 |
+
'''
|
62 |
+
if labels is not None:
|
63 |
+
self.labels = labels
|
64 |
+
|
65 |
+
xyz = self.projection(points, calibs, transforms)
|
66 |
+
xy = xyz[:, :2, :]
|
67 |
+
z = xyz[:, 2:3, :]
|
68 |
+
|
69 |
+
z_feat = self.normalizer(z)
|
70 |
+
|
71 |
+
# This is a list of [B, Feat_i, N] features
|
72 |
+
point_local_feat_list = [self.index(self.im_feat, xy), z_feat]
|
73 |
+
# [B, Feat_all, N]
|
74 |
+
point_local_feat = torch.cat(point_local_feat_list, 1)
|
75 |
+
|
76 |
+
self.preds = self.surface_classifier(point_local_feat)
|
77 |
+
|
78 |
+
def forward(self, images, im_feat, points, calibs, transforms=None, labels=None):
|
79 |
+
self.filter(images)
|
80 |
+
|
81 |
+
self.attach(im_feat)
|
82 |
+
|
83 |
+
self.query(points, calibs, transforms, labels)
|
84 |
+
|
85 |
+
res = self.get_preds()
|
86 |
+
error = self.get_error()
|
87 |
+
|
88 |
+
return res, error
|
89 |
+
|
90 |
+
class ResnetBlock(nn.Module):
|
91 |
+
"""Define a Resnet block"""
|
92 |
+
|
93 |
+
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
|
94 |
+
"""Initialize the Resnet block
|
95 |
+
A resnet block is a conv block with skip connections
|
96 |
+
We construct a conv block with build_conv_block function,
|
97 |
+
and implement skip connections in <forward> function.
|
98 |
+
Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
|
99 |
+
"""
|
100 |
+
super(ResnetBlock, self).__init__()
|
101 |
+
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, last)
|
102 |
+
|
103 |
+
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
|
104 |
+
"""Construct a convolutional block.
|
105 |
+
Parameters:
|
106 |
+
dim (int) -- the number of channels in the conv layer.
|
107 |
+
padding_type (str) -- the name of padding layer: reflect | replicate | zero
|
108 |
+
norm_layer -- normalization layer
|
109 |
+
use_dropout (bool) -- if use dropout layers.
|
110 |
+
use_bias (bool) -- if the conv layer uses bias or not
|
111 |
+
Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
|
112 |
+
"""
|
113 |
+
conv_block = []
|
114 |
+
p = 0
|
115 |
+
if padding_type == 'reflect':
|
116 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
117 |
+
elif padding_type == 'replicate':
|
118 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
119 |
+
elif padding_type == 'zero':
|
120 |
+
p = 1
|
121 |
+
else:
|
122 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
123 |
+
|
124 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
|
125 |
+
if use_dropout:
|
126 |
+
conv_block += [nn.Dropout(0.5)]
|
127 |
+
|
128 |
+
p = 0
|
129 |
+
if padding_type == 'reflect':
|
130 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
131 |
+
elif padding_type == 'replicate':
|
132 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
133 |
+
elif padding_type == 'zero':
|
134 |
+
p = 1
|
135 |
+
else:
|
136 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
137 |
+
if last:
|
138 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
|
139 |
+
else:
|
140 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
|
141 |
+
|
142 |
+
return nn.Sequential(*conv_block)
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
"""Forward function (with skip connections)"""
|
146 |
+
out = x + self.conv_block(x) # add skip connections
|
147 |
+
return out
|
148 |
+
|
149 |
+
|
150 |
+
class ResnetFilter(nn.Module):
|
151 |
+
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
|
152 |
+
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
|
153 |
+
"""
|
154 |
+
|
155 |
+
def __init__(self, opt, input_nc=3, output_nc=256, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
|
156 |
+
n_blocks=6, padding_type='reflect'):
|
157 |
+
"""Construct a Resnet-based generator
|
158 |
+
Parameters:
|
159 |
+
input_nc (int) -- the number of channels in input images
|
160 |
+
output_nc (int) -- the number of channels in output images
|
161 |
+
ngf (int) -- the number of filters in the last conv layer
|
162 |
+
norm_layer -- normalization layer
|
163 |
+
use_dropout (bool) -- if use dropout layers
|
164 |
+
n_blocks (int) -- the number of ResNet blocks
|
165 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
166 |
+
"""
|
167 |
+
assert (n_blocks >= 0)
|
168 |
+
super(ResnetFilter, self).__init__()
|
169 |
+
if type(norm_layer) == functools.partial:
|
170 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
171 |
+
else:
|
172 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
173 |
+
|
174 |
+
model = [nn.ReflectionPad2d(3),
|
175 |
+
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
|
176 |
+
norm_layer(ngf),
|
177 |
+
nn.ReLU(True)]
|
178 |
+
|
179 |
+
n_downsampling = 2
|
180 |
+
for i in range(n_downsampling): # add downsampling layers
|
181 |
+
mult = 2 ** i
|
182 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
|
183 |
+
norm_layer(ngf * mult * 2),
|
184 |
+
nn.ReLU(True)]
|
185 |
+
|
186 |
+
mult = 2 ** n_downsampling
|
187 |
+
for i in range(n_blocks): # add ResNet blocks
|
188 |
+
if i == n_blocks - 1:
|
189 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
|
190 |
+
use_dropout=use_dropout, use_bias=use_bias, last=True)]
|
191 |
+
else:
|
192 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
|
193 |
+
use_dropout=use_dropout, use_bias=use_bias)]
|
194 |
+
|
195 |
+
if opt.use_tanh:
|
196 |
+
model += [nn.Tanh()]
|
197 |
+
self.model = nn.Sequential(*model)
|
198 |
+
|
199 |
+
def forward(self, input):
|
200 |
+
"""Standard forward"""
|
201 |
+
return self.model(input)
|
PIFu/lib/model/SurfaceClassifier.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class SurfaceClassifier(nn.Module):
|
7 |
+
def __init__(self, filter_channels, num_views=1, no_residual=True, last_op=None):
|
8 |
+
super(SurfaceClassifier, self).__init__()
|
9 |
+
|
10 |
+
self.filters = []
|
11 |
+
self.num_views = num_views
|
12 |
+
self.no_residual = no_residual
|
13 |
+
filter_channels = filter_channels
|
14 |
+
self.last_op = last_op
|
15 |
+
|
16 |
+
if self.no_residual:
|
17 |
+
for l in range(0, len(filter_channels) - 1):
|
18 |
+
self.filters.append(nn.Conv1d(
|
19 |
+
filter_channels[l],
|
20 |
+
filter_channels[l + 1],
|
21 |
+
1))
|
22 |
+
self.add_module("conv%d" % l, self.filters[l])
|
23 |
+
else:
|
24 |
+
for l in range(0, len(filter_channels) - 1):
|
25 |
+
if 0 != l:
|
26 |
+
self.filters.append(
|
27 |
+
nn.Conv1d(
|
28 |
+
filter_channels[l] + filter_channels[0],
|
29 |
+
filter_channels[l + 1],
|
30 |
+
1))
|
31 |
+
else:
|
32 |
+
self.filters.append(nn.Conv1d(
|
33 |
+
filter_channels[l],
|
34 |
+
filter_channels[l + 1],
|
35 |
+
1))
|
36 |
+
|
37 |
+
self.add_module("conv%d" % l, self.filters[l])
|
38 |
+
|
39 |
+
def forward(self, feature):
|
40 |
+
'''
|
41 |
+
|
42 |
+
:param feature: list of [BxC_inxHxW] tensors of image features
|
43 |
+
:param xy: [Bx3xN] tensor of (x,y) coodinates in the image plane
|
44 |
+
:return: [BxC_outxN] tensor of features extracted at the coordinates
|
45 |
+
'''
|
46 |
+
|
47 |
+
y = feature
|
48 |
+
tmpy = feature
|
49 |
+
for i, f in enumerate(self.filters):
|
50 |
+
if self.no_residual:
|
51 |
+
y = self._modules['conv' + str(i)](y)
|
52 |
+
else:
|
53 |
+
y = self._modules['conv' + str(i)](
|
54 |
+
y if i == 0
|
55 |
+
else torch.cat([y, tmpy], 1)
|
56 |
+
)
|
57 |
+
if i != len(self.filters) - 1:
|
58 |
+
y = F.leaky_relu(y)
|
59 |
+
|
60 |
+
if self.num_views > 1 and i == len(self.filters) // 2:
|
61 |
+
y = y.view(
|
62 |
+
-1, self.num_views, y.shape[1], y.shape[2]
|
63 |
+
).mean(dim=1)
|
64 |
+
tmpy = feature.view(
|
65 |
+
-1, self.num_views, feature.shape[1], feature.shape[2]
|
66 |
+
).mean(dim=1)
|
67 |
+
|
68 |
+
if self.last_op:
|
69 |
+
y = self.last_op(y)
|
70 |
+
|
71 |
+
return y
|
PIFu/lib/model/VhullPIFuNet.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from .BasePIFuNet import BasePIFuNet
|
5 |
+
|
6 |
+
|
7 |
+
class VhullPIFuNet(BasePIFuNet):
|
8 |
+
'''
|
9 |
+
Vhull Piximp network is a minimal network demonstrating how the template works
|
10 |
+
also, it helps debugging the training/test schemes
|
11 |
+
It does the following:
|
12 |
+
1. Compute the masks of images and stores under self.im_feats
|
13 |
+
2. Calculate calibration and indexing
|
14 |
+
3. Return if the points fall into the intersection of all masks
|
15 |
+
'''
|
16 |
+
|
17 |
+
def __init__(self,
|
18 |
+
num_views,
|
19 |
+
projection_mode='orthogonal',
|
20 |
+
error_term=nn.MSELoss(),
|
21 |
+
):
|
22 |
+
super(VhullPIFuNet, self).__init__(
|
23 |
+
projection_mode=projection_mode,
|
24 |
+
error_term=error_term)
|
25 |
+
self.name = 'vhull'
|
26 |
+
|
27 |
+
self.num_views = num_views
|
28 |
+
|
29 |
+
self.im_feat = None
|
30 |
+
|
31 |
+
def filter(self, images):
|
32 |
+
'''
|
33 |
+
Filter the input images
|
34 |
+
store all intermediate features.
|
35 |
+
:param images: [B, C, H, W] input images
|
36 |
+
'''
|
37 |
+
# If the image has alpha channel, use the alpha channel
|
38 |
+
if images.shape[1] > 3:
|
39 |
+
self.im_feat = images[:, 3:4, :, :]
|
40 |
+
# Else, tell if it's not white
|
41 |
+
else:
|
42 |
+
self.im_feat = images[:, 0:1, :, :]
|
43 |
+
|
44 |
+
def query(self, points, calibs, transforms=None, labels=None):
|
45 |
+
'''
|
46 |
+
Given 3D points, query the network predictions for each point.
|
47 |
+
Image features should be pre-computed before this call.
|
48 |
+
store all intermediate features.
|
49 |
+
query() function may behave differently during training/testing.
|
50 |
+
:param points: [B, 3, N] world space coordinates of points
|
51 |
+
:param calibs: [B, 3, 4] calibration matrices for each image
|
52 |
+
:param transforms: Optional [B, 2, 3] image space coordinate transforms
|
53 |
+
:param labels: Optional [B, Res, N] gt labeling
|
54 |
+
:return: [B, Res, N] predictions for each point
|
55 |
+
'''
|
56 |
+
if labels is not None:
|
57 |
+
self.labels = labels
|
58 |
+
|
59 |
+
xyz = self.projection(points, calibs, transforms)
|
60 |
+
xy = xyz[:, :2, :]
|
61 |
+
|
62 |
+
point_local_feat = self.index(self.im_feat, xy)
|
63 |
+
local_shape = point_local_feat.shape
|
64 |
+
point_feat = point_local_feat.view(
|
65 |
+
local_shape[0] // self.num_views,
|
66 |
+
local_shape[1] * self.num_views,
|
67 |
+
-1)
|
68 |
+
pred = torch.prod(point_feat, dim=1)
|
69 |
+
|
70 |
+
self.preds = pred.unsqueeze(1)
|
PIFu/lib/model/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .BasePIFuNet import BasePIFuNet
|
2 |
+
from .VhullPIFuNet import VhullPIFuNet
|
3 |
+
from .ConvPIFuNet import ConvPIFuNet
|
4 |
+
from .HGPIFuNet import HGPIFuNet
|
5 |
+
from .ResBlkPIFuNet import ResBlkPIFuNet
|
PIFu/lib/net_util.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.nn import init
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import functools
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from .mesh_util import *
|
9 |
+
from .sample_util import *
|
10 |
+
from .geometry import index
|
11 |
+
import cv2
|
12 |
+
from PIL import Image
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
|
16 |
+
def reshape_multiview_tensors(image_tensor, calib_tensor):
|
17 |
+
# Careful here! Because we put single view and multiview together,
|
18 |
+
# the returned tensor.shape is 5-dim: [B, num_views, C, W, H]
|
19 |
+
# So we need to convert it back to 4-dim [B*num_views, C, W, H]
|
20 |
+
# Don't worry classifier will handle multi-view cases
|
21 |
+
image_tensor = image_tensor.view(
|
22 |
+
image_tensor.shape[0] * image_tensor.shape[1],
|
23 |
+
image_tensor.shape[2],
|
24 |
+
image_tensor.shape[3],
|
25 |
+
image_tensor.shape[4]
|
26 |
+
)
|
27 |
+
calib_tensor = calib_tensor.view(
|
28 |
+
calib_tensor.shape[0] * calib_tensor.shape[1],
|
29 |
+
calib_tensor.shape[2],
|
30 |
+
calib_tensor.shape[3]
|
31 |
+
)
|
32 |
+
|
33 |
+
return image_tensor, calib_tensor
|
34 |
+
|
35 |
+
|
36 |
+
def reshape_sample_tensor(sample_tensor, num_views):
|
37 |
+
if num_views == 1:
|
38 |
+
return sample_tensor
|
39 |
+
# Need to repeat sample_tensor along the batch dim num_views times
|
40 |
+
sample_tensor = sample_tensor.unsqueeze(dim=1)
|
41 |
+
sample_tensor = sample_tensor.repeat(1, num_views, 1, 1)
|
42 |
+
sample_tensor = sample_tensor.view(
|
43 |
+
sample_tensor.shape[0] * sample_tensor.shape[1],
|
44 |
+
sample_tensor.shape[2],
|
45 |
+
sample_tensor.shape[3]
|
46 |
+
)
|
47 |
+
return sample_tensor
|
48 |
+
|
49 |
+
|
50 |
+
def gen_mesh(opt, net, cuda, data, save_path, use_octree=True):
|
51 |
+
image_tensor = data['img'].to(device=cuda)
|
52 |
+
calib_tensor = data['calib'].to(device=cuda)
|
53 |
+
|
54 |
+
net.filter(image_tensor)
|
55 |
+
|
56 |
+
b_min = data['b_min']
|
57 |
+
b_max = data['b_max']
|
58 |
+
try:
|
59 |
+
save_img_path = save_path[:-4] + '.png'
|
60 |
+
save_img_list = []
|
61 |
+
for v in range(image_tensor.shape[0]):
|
62 |
+
save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
|
63 |
+
save_img_list.append(save_img)
|
64 |
+
save_img = np.concatenate(save_img_list, axis=1)
|
65 |
+
Image.fromarray(np.uint8(save_img[:,:,::-1])).save(save_img_path)
|
66 |
+
|
67 |
+
verts, faces, _, _ = reconstruction(
|
68 |
+
net, cuda, calib_tensor, opt.resolution, b_min, b_max, use_octree=use_octree)
|
69 |
+
verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to(device=cuda).float()
|
70 |
+
xyz_tensor = net.projection(verts_tensor, calib_tensor[:1])
|
71 |
+
uv = xyz_tensor[:, :2, :]
|
72 |
+
color = index(image_tensor[:1], uv).detach().cpu().numpy()[0].T
|
73 |
+
color = color * 0.5 + 0.5
|
74 |
+
save_obj_mesh_with_color(save_path, verts, faces, color)
|
75 |
+
except Exception as e:
|
76 |
+
print(e)
|
77 |
+
print('Can not create marching cubes at this time.')
|
78 |
+
|
79 |
+
def gen_mesh_color(opt, netG, netC, cuda, data, save_path, use_octree=True):
|
80 |
+
image_tensor = data['img'].to(device=cuda)
|
81 |
+
calib_tensor = data['calib'].to(device=cuda)
|
82 |
+
|
83 |
+
netG.filter(image_tensor)
|
84 |
+
netC.filter(image_tensor)
|
85 |
+
netC.attach(netG.get_im_feat())
|
86 |
+
|
87 |
+
b_min = data['b_min']
|
88 |
+
b_max = data['b_max']
|
89 |
+
try:
|
90 |
+
save_img_path = save_path[:-4] + '.png'
|
91 |
+
save_img_list = []
|
92 |
+
for v in range(image_tensor.shape[0]):
|
93 |
+
save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
|
94 |
+
save_img_list.append(save_img)
|
95 |
+
save_img = np.concatenate(save_img_list, axis=1)
|
96 |
+
Image.fromarray(np.uint8(save_img[:,:,::-1])).save(save_img_path)
|
97 |
+
|
98 |
+
verts, faces, _, _ = reconstruction(
|
99 |
+
netG, cuda, calib_tensor, opt.resolution, b_min, b_max, use_octree=use_octree)
|
100 |
+
|
101 |
+
# Now Getting colors
|
102 |
+
verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to(device=cuda).float()
|
103 |
+
verts_tensor = reshape_sample_tensor(verts_tensor, opt.num_views)
|
104 |
+
|
105 |
+
color = np.zeros(verts.shape)
|
106 |
+
interval = opt.num_sample_color
|
107 |
+
for i in range(len(color) // interval):
|
108 |
+
left = i * interval
|
109 |
+
right = i * interval + interval
|
110 |
+
if i == len(color) // interval - 1:
|
111 |
+
right = -1
|
112 |
+
netC.query(verts_tensor[:, :, left:right], calib_tensor)
|
113 |
+
rgb = netC.get_preds()[0].detach().cpu().numpy() * 0.5 + 0.5
|
114 |
+
color[left:right] = rgb.T
|
115 |
+
|
116 |
+
save_obj_mesh_with_color(save_path, verts, faces, color)
|
117 |
+
except Exception as e:
|
118 |
+
print(e)
|
119 |
+
print('Can not create marching cubes at this time.')
|
120 |
+
|
121 |
+
def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
|
122 |
+
"""Sets the learning rate to the initial LR decayed by schedule"""
|
123 |
+
if epoch in schedule:
|
124 |
+
lr *= gamma
|
125 |
+
for param_group in optimizer.param_groups:
|
126 |
+
param_group['lr'] = lr
|
127 |
+
return lr
|
128 |
+
|
129 |
+
|
130 |
+
def compute_acc(pred, gt, thresh=0.5):
|
131 |
+
'''
|
132 |
+
return:
|
133 |
+
IOU, precision, and recall
|
134 |
+
'''
|
135 |
+
with torch.no_grad():
|
136 |
+
vol_pred = pred > thresh
|
137 |
+
vol_gt = gt > thresh
|
138 |
+
|
139 |
+
union = vol_pred | vol_gt
|
140 |
+
inter = vol_pred & vol_gt
|
141 |
+
|
142 |
+
true_pos = inter.sum().float()
|
143 |
+
|
144 |
+
union = union.sum().float()
|
145 |
+
if union == 0:
|
146 |
+
union = 1
|
147 |
+
vol_pred = vol_pred.sum().float()
|
148 |
+
if vol_pred == 0:
|
149 |
+
vol_pred = 1
|
150 |
+
vol_gt = vol_gt.sum().float()
|
151 |
+
if vol_gt == 0:
|
152 |
+
vol_gt = 1
|
153 |
+
return true_pos / union, true_pos / vol_pred, true_pos / vol_gt
|
154 |
+
|
155 |
+
|
156 |
+
def calc_error(opt, net, cuda, dataset, num_tests):
|
157 |
+
if num_tests > len(dataset):
|
158 |
+
num_tests = len(dataset)
|
159 |
+
with torch.no_grad():
|
160 |
+
erorr_arr, IOU_arr, prec_arr, recall_arr = [], [], [], []
|
161 |
+
for idx in tqdm(range(num_tests)):
|
162 |
+
data = dataset[idx * len(dataset) // num_tests]
|
163 |
+
# retrieve the data
|
164 |
+
image_tensor = data['img'].to(device=cuda)
|
165 |
+
calib_tensor = data['calib'].to(device=cuda)
|
166 |
+
sample_tensor = data['samples'].to(device=cuda).unsqueeze(0)
|
167 |
+
if opt.num_views > 1:
|
168 |
+
sample_tensor = reshape_sample_tensor(sample_tensor, opt.num_views)
|
169 |
+
label_tensor = data['labels'].to(device=cuda).unsqueeze(0)
|
170 |
+
|
171 |
+
res, error = net.forward(image_tensor, sample_tensor, calib_tensor, labels=label_tensor)
|
172 |
+
|
173 |
+
IOU, prec, recall = compute_acc(res, label_tensor)
|
174 |
+
|
175 |
+
# print(
|
176 |
+
# '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
|
177 |
+
# .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item()))
|
178 |
+
erorr_arr.append(error.item())
|
179 |
+
IOU_arr.append(IOU.item())
|
180 |
+
prec_arr.append(prec.item())
|
181 |
+
recall_arr.append(recall.item())
|
182 |
+
|
183 |
+
return np.average(erorr_arr), np.average(IOU_arr), np.average(prec_arr), np.average(recall_arr)
|
184 |
+
|
185 |
+
def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):
|
186 |
+
if num_tests > len(dataset):
|
187 |
+
num_tests = len(dataset)
|
188 |
+
with torch.no_grad():
|
189 |
+
error_color_arr = []
|
190 |
+
|
191 |
+
for idx in tqdm(range(num_tests)):
|
192 |
+
data = dataset[idx * len(dataset) // num_tests]
|
193 |
+
# retrieve the data
|
194 |
+
image_tensor = data['img'].to(device=cuda)
|
195 |
+
calib_tensor = data['calib'].to(device=cuda)
|
196 |
+
color_sample_tensor = data['color_samples'].to(device=cuda).unsqueeze(0)
|
197 |
+
|
198 |
+
if opt.num_views > 1:
|
199 |
+
color_sample_tensor = reshape_sample_tensor(color_sample_tensor, opt.num_views)
|
200 |
+
|
201 |
+
rgb_tensor = data['rgbs'].to(device=cuda).unsqueeze(0)
|
202 |
+
|
203 |
+
netG.filter(image_tensor)
|
204 |
+
_, errorC = netC.forward(image_tensor, netG.get_im_feat(), color_sample_tensor, calib_tensor, labels=rgb_tensor)
|
205 |
+
|
206 |
+
# print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}'
|
207 |
+
# .format(idx, num_tests, errorG.item(), errorC.item()))
|
208 |
+
error_color_arr.append(errorC.item())
|
209 |
+
|
210 |
+
return np.average(error_color_arr)
|
211 |
+
|
212 |
+
|
213 |
+
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
|
214 |
+
"3x3 convolution with padding"
|
215 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
216 |
+
stride=strd, padding=padding, bias=bias)
|
217 |
+
|
218 |
+
def init_weights(net, init_type='normal', init_gain=0.02):
|
219 |
+
"""Initialize network weights.
|
220 |
+
|
221 |
+
Parameters:
|
222 |
+
net (network) -- network to be initialized
|
223 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
224 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
225 |
+
|
226 |
+
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
|
227 |
+
work better for some applications. Feel free to try yourself.
|
228 |
+
"""
|
229 |
+
|
230 |
+
def init_func(m): # define the initialization function
|
231 |
+
classname = m.__class__.__name__
|
232 |
+
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
233 |
+
if init_type == 'normal':
|
234 |
+
init.normal_(m.weight.data, 0.0, init_gain)
|
235 |
+
elif init_type == 'xavier':
|
236 |
+
init.xavier_normal_(m.weight.data, gain=init_gain)
|
237 |
+
elif init_type == 'kaiming':
|
238 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
239 |
+
elif init_type == 'orthogonal':
|
240 |
+
init.orthogonal_(m.weight.data, gain=init_gain)
|
241 |
+
else:
|
242 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
243 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
244 |
+
init.constant_(m.bias.data, 0.0)
|
245 |
+
elif classname.find(
|
246 |
+
'BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
|
247 |
+
init.normal_(m.weight.data, 1.0, init_gain)
|
248 |
+
init.constant_(m.bias.data, 0.0)
|
249 |
+
|
250 |
+
print('initialize network with %s' % init_type)
|
251 |
+
net.apply(init_func) # apply the initialization function <init_func>
|
252 |
+
|
253 |
+
|
254 |
+
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
|
255 |
+
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
|
256 |
+
Parameters:
|
257 |
+
net (network) -- the network to be initialized
|
258 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
259 |
+
gain (float) -- scaling factor for normal, xavier and orthogonal.
|
260 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
261 |
+
|
262 |
+
Return an initialized network.
|
263 |
+
"""
|
264 |
+
if len(gpu_ids) > 0:
|
265 |
+
assert (torch.cuda.is_available())
|
266 |
+
net.to(gpu_ids[0])
|
267 |
+
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
|
268 |
+
init_weights(net, init_type, init_gain=init_gain)
|
269 |
+
return net
|
270 |
+
|
271 |
+
|
272 |
+
def imageSpaceRotation(xy, rot):
|
273 |
+
'''
|
274 |
+
args:
|
275 |
+
xy: (B, 2, N) input
|
276 |
+
rot: (B, 2) x,y axis rotation angles
|
277 |
+
|
278 |
+
rotation center will be always image center (other rotation center can be represented by additional z translation)
|
279 |
+
'''
|
280 |
+
disp = rot.unsqueeze(2).sin().expand_as(xy)
|
281 |
+
return (disp * xy).sum(dim=1)
|
282 |
+
|
283 |
+
|
284 |
+
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
|
285 |
+
"""Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
|
286 |
+
|
287 |
+
Arguments:
|
288 |
+
netD (network) -- discriminator network
|
289 |
+
real_data (tensor array) -- real images
|
290 |
+
fake_data (tensor array) -- generated images from the generator
|
291 |
+
device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
|
292 |
+
type (str) -- if we mix real and fake data or not [real | fake | mixed].
|
293 |
+
constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2
|
294 |
+
lambda_gp (float) -- weight for this loss
|
295 |
+
|
296 |
+
Returns the gradient penalty loss
|
297 |
+
"""
|
298 |
+
if lambda_gp > 0.0:
|
299 |
+
if type == 'real': # either use real images, fake images, or a linear interpolation of two.
|
300 |
+
interpolatesv = real_data
|
301 |
+
elif type == 'fake':
|
302 |
+
interpolatesv = fake_data
|
303 |
+
elif type == 'mixed':
|
304 |
+
alpha = torch.rand(real_data.shape[0], 1)
|
305 |
+
alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(
|
306 |
+
*real_data.shape)
|
307 |
+
alpha = alpha.to(device)
|
308 |
+
interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
|
309 |
+
else:
|
310 |
+
raise NotImplementedError('{} not implemented'.format(type))
|
311 |
+
interpolatesv.requires_grad_(True)
|
312 |
+
disc_interpolates = netD(interpolatesv)
|
313 |
+
gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
|
314 |
+
grad_outputs=torch.ones(disc_interpolates.size()).to(device),
|
315 |
+
create_graph=True, retain_graph=True, only_inputs=True)
|
316 |
+
gradients = gradients[0].view(real_data.size(0), -1) # flat the data
|
317 |
+
gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
|
318 |
+
return gradient_penalty, gradients
|
319 |
+
else:
|
320 |
+
return 0.0, None
|
321 |
+
|
322 |
+
def get_norm_layer(norm_type='instance'):
|
323 |
+
"""Return a normalization layer
|
324 |
+
Parameters:
|
325 |
+
norm_type (str) -- the name of the normalization layer: batch | instance | none
|
326 |
+
For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
|
327 |
+
For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
|
328 |
+
"""
|
329 |
+
if norm_type == 'batch':
|
330 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
331 |
+
elif norm_type == 'instance':
|
332 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
333 |
+
elif norm_type == 'group':
|
334 |
+
norm_layer = functools.partial(nn.GroupNorm, 32)
|
335 |
+
elif norm_type == 'none':
|
336 |
+
norm_layer = None
|
337 |
+
else:
|
338 |
+
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
339 |
+
return norm_layer
|
340 |
+
|
341 |
+
class Flatten(nn.Module):
|
342 |
+
def forward(self, input):
|
343 |
+
return input.view(input.size(0), -1)
|
344 |
+
|
345 |
+
class ConvBlock(nn.Module):
|
346 |
+
def __init__(self, in_planes, out_planes, norm='batch'):
|
347 |
+
super(ConvBlock, self).__init__()
|
348 |
+
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
|
349 |
+
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
|
350 |
+
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
|
351 |
+
|
352 |
+
if norm == 'batch':
|
353 |
+
self.bn1 = nn.BatchNorm2d(in_planes)
|
354 |
+
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
|
355 |
+
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
|
356 |
+
self.bn4 = nn.BatchNorm2d(in_planes)
|
357 |
+
elif norm == 'group':
|
358 |
+
self.bn1 = nn.GroupNorm(32, in_planes)
|
359 |
+
self.bn2 = nn.GroupNorm(32, int(out_planes / 2))
|
360 |
+
self.bn3 = nn.GroupNorm(32, int(out_planes / 4))
|
361 |
+
self.bn4 = nn.GroupNorm(32, in_planes)
|
362 |
+
|
363 |
+
if in_planes != out_planes:
|
364 |
+
self.downsample = nn.Sequential(
|
365 |
+
self.bn4,
|
366 |
+
nn.ReLU(True),
|
367 |
+
nn.Conv2d(in_planes, out_planes,
|
368 |
+
kernel_size=1, stride=1, bias=False),
|
369 |
+
)
|
370 |
+
else:
|
371 |
+
self.downsample = None
|
372 |
+
|
373 |
+
def forward(self, x):
|
374 |
+
residual = x
|
375 |
+
|
376 |
+
out1 = self.bn1(x)
|
377 |
+
out1 = F.relu(out1, True)
|
378 |
+
out1 = self.conv1(out1)
|
379 |
+
|
380 |
+
out2 = self.bn2(out1)
|
381 |
+
out2 = F.relu(out2, True)
|
382 |
+
out2 = self.conv2(out2)
|
383 |
+
|
384 |
+
out3 = self.bn3(out2)
|
385 |
+
out3 = F.relu(out3, True)
|
386 |
+
out3 = self.conv3(out3)
|
387 |
+
|
388 |
+
out3 = torch.cat((out1, out2, out3), 1)
|
389 |
+
|
390 |
+
if self.downsample is not None:
|
391 |
+
residual = self.downsample(residual)
|
392 |
+
|
393 |
+
out3 += residual
|
394 |
+
|
395 |
+
return out3
|
396 |
+
|
PIFu/lib/options.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
|
5 |
+
class BaseOptions():
|
6 |
+
def __init__(self):
|
7 |
+
self.initialized = False
|
8 |
+
argparse
|
9 |
+
def initialize(self, parser):
|
10 |
+
# Datasets related
|
11 |
+
g_data = parser.add_argument_group('Data')
|
12 |
+
g_data.add_argument('--dataroot', type=str, default='./data',
|
13 |
+
help='path to images (data folder)')
|
14 |
+
|
15 |
+
g_data.add_argument('--loadSize', type=int, default=512, help='load size of input image')
|
16 |
+
|
17 |
+
# Experiment related
|
18 |
+
g_exp = parser.add_argument_group('Experiment')
|
19 |
+
g_exp.add_argument('--name', type=str, default='example',
|
20 |
+
help='name of the experiment. It decides where to store samples and models')
|
21 |
+
g_exp.add_argument('--debug', action='store_true', help='debug mode or not')
|
22 |
+
|
23 |
+
g_exp.add_argument('--num_views', type=int, default=1, help='How many views to use for multiview network.')
|
24 |
+
g_exp.add_argument('--random_multiview', action='store_true', help='Select random multiview combination.')
|
25 |
+
|
26 |
+
# Training related
|
27 |
+
g_train = parser.add_argument_group('Training')
|
28 |
+
g_train.add_argument('--gpu_id', type=int, default=0, help='gpu id for cuda')
|
29 |
+
g_train.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2, -1 for CPU mode')
|
30 |
+
|
31 |
+
g_train.add_argument('--num_threads', default=1, type=int, help='# sthreads for loading data')
|
32 |
+
g_train.add_argument('--serial_batches', action='store_true',
|
33 |
+
help='if true, takes images in order to make batches, otherwise takes them randomly')
|
34 |
+
g_train.add_argument('--pin_memory', action='store_true', help='pin_memory')
|
35 |
+
|
36 |
+
g_train.add_argument('--batch_size', type=int, default=2, help='input batch size')
|
37 |
+
g_train.add_argument('--learning_rate', type=float, default=1e-3, help='adam learning rate')
|
38 |
+
g_train.add_argument('--learning_rateC', type=float, default=1e-3, help='adam learning rate')
|
39 |
+
g_train.add_argument('--num_epoch', type=int, default=100, help='num epoch to train')
|
40 |
+
|
41 |
+
g_train.add_argument('--freq_plot', type=int, default=10, help='freqency of the error plot')
|
42 |
+
g_train.add_argument('--freq_save', type=int, default=50, help='freqency of the save_checkpoints')
|
43 |
+
g_train.add_argument('--freq_save_ply', type=int, default=100, help='freqency of the save ply')
|
44 |
+
|
45 |
+
g_train.add_argument('--no_gen_mesh', action='store_true')
|
46 |
+
g_train.add_argument('--no_num_eval', action='store_true')
|
47 |
+
|
48 |
+
g_train.add_argument('--resume_epoch', type=int, default=-1, help='epoch resuming the training')
|
49 |
+
g_train.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
|
50 |
+
|
51 |
+
# Testing related
|
52 |
+
g_test = parser.add_argument_group('Testing')
|
53 |
+
g_test.add_argument('--resolution', type=int, default=256, help='# of grid in mesh reconstruction')
|
54 |
+
g_test.add_argument('--test_folder_path', type=str, default=None, help='the folder of test image')
|
55 |
+
|
56 |
+
# Sampling related
|
57 |
+
g_sample = parser.add_argument_group('Sampling')
|
58 |
+
g_sample.add_argument('--sigma', type=float, default=5.0, help='perturbation standard deviation for positions')
|
59 |
+
|
60 |
+
g_sample.add_argument('--num_sample_inout', type=int, default=5000, help='# of sampling points')
|
61 |
+
g_sample.add_argument('--num_sample_color', type=int, default=0, help='# of sampling points')
|
62 |
+
|
63 |
+
g_sample.add_argument('--z_size', type=float, default=200.0, help='z normalization factor')
|
64 |
+
|
65 |
+
# Model related
|
66 |
+
g_model = parser.add_argument_group('Model')
|
67 |
+
# General
|
68 |
+
g_model.add_argument('--norm', type=str, default='group',
|
69 |
+
help='instance normalization or batch normalization or group normalization')
|
70 |
+
g_model.add_argument('--norm_color', type=str, default='instance',
|
71 |
+
help='instance normalization or batch normalization or group normalization')
|
72 |
+
|
73 |
+
# hg filter specify
|
74 |
+
g_model.add_argument('--num_stack', type=int, default=4, help='# of hourglass')
|
75 |
+
g_model.add_argument('--num_hourglass', type=int, default=2, help='# of stacked layer of hourglass')
|
76 |
+
g_model.add_argument('--skip_hourglass', action='store_true', help='skip connection in hourglass')
|
77 |
+
g_model.add_argument('--hg_down', type=str, default='ave_pool', help='ave pool || conv64 || conv128')
|
78 |
+
g_model.add_argument('--hourglass_dim', type=int, default='256', help='256 | 512')
|
79 |
+
|
80 |
+
# Classification General
|
81 |
+
g_model.add_argument('--mlp_dim', nargs='+', default=[257, 1024, 512, 256, 128, 1], type=int,
|
82 |
+
help='# of dimensions of mlp')
|
83 |
+
g_model.add_argument('--mlp_dim_color', nargs='+', default=[513, 1024, 512, 256, 128, 3],
|
84 |
+
type=int, help='# of dimensions of color mlp')
|
85 |
+
|
86 |
+
g_model.add_argument('--use_tanh', action='store_true',
|
87 |
+
help='using tanh after last conv of image_filter network')
|
88 |
+
|
89 |
+
# for train
|
90 |
+
parser.add_argument('--random_flip', action='store_true', help='if random flip')
|
91 |
+
parser.add_argument('--random_trans', action='store_true', help='if random flip')
|
92 |
+
parser.add_argument('--random_scale', action='store_true', help='if random flip')
|
93 |
+
parser.add_argument('--no_residual', action='store_true', help='no skip connection in mlp')
|
94 |
+
parser.add_argument('--schedule', type=int, nargs='+', default=[60, 80],
|
95 |
+
help='Decrease learning rate at these epochs.')
|
96 |
+
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
|
97 |
+
parser.add_argument('--color_loss_type', type=str, default='l1', help='mse | l1')
|
98 |
+
|
99 |
+
# for eval
|
100 |
+
parser.add_argument('--val_test_error', action='store_true', help='validate errors of test data')
|
101 |
+
parser.add_argument('--val_train_error', action='store_true', help='validate errors of train data')
|
102 |
+
parser.add_argument('--gen_test_mesh', action='store_true', help='generate test mesh')
|
103 |
+
parser.add_argument('--gen_train_mesh', action='store_true', help='generate train mesh')
|
104 |
+
parser.add_argument('--all_mesh', action='store_true', help='generate meshs from all hourglass output')
|
105 |
+
parser.add_argument('--num_gen_mesh_test', type=int, default=1,
|
106 |
+
help='how many meshes to generate during testing')
|
107 |
+
|
108 |
+
# path
|
109 |
+
parser.add_argument('--checkpoints_path', type=str, default='./checkpoints', help='path to save checkpoints')
|
110 |
+
parser.add_argument('--load_netG_checkpoint_path', type=str, default=None, help='path to save checkpoints')
|
111 |
+
parser.add_argument('--load_netC_checkpoint_path', type=str, default=None, help='path to save checkpoints')
|
112 |
+
parser.add_argument('--results_path', type=str, default='./results', help='path to save results ply')
|
113 |
+
parser.add_argument('--load_checkpoint_path', type=str, help='path to save results ply')
|
114 |
+
parser.add_argument('--single', type=str, default='', help='single data for training')
|
115 |
+
# for single image reconstruction
|
116 |
+
parser.add_argument('--mask_path', type=str, help='path for input mask')
|
117 |
+
parser.add_argument('--img_path', type=str, help='path for input image')
|
118 |
+
|
119 |
+
# aug
|
120 |
+
group_aug = parser.add_argument_group('aug')
|
121 |
+
group_aug.add_argument('--aug_alstd', type=float, default=0.0, help='augmentation pca lighting alpha std')
|
122 |
+
group_aug.add_argument('--aug_bri', type=float, default=0.0, help='augmentation brightness')
|
123 |
+
group_aug.add_argument('--aug_con', type=float, default=0.0, help='augmentation contrast')
|
124 |
+
group_aug.add_argument('--aug_sat', type=float, default=0.0, help='augmentation saturation')
|
125 |
+
group_aug.add_argument('--aug_hue', type=float, default=0.0, help='augmentation hue')
|
126 |
+
group_aug.add_argument('--aug_blur', type=float, default=0.0, help='augmentation blur')
|
127 |
+
|
128 |
+
# special tasks
|
129 |
+
self.initialized = True
|
130 |
+
return parser
|
131 |
+
|
132 |
+
def gather_options(self):
|
133 |
+
# initialize parser with basic options
|
134 |
+
if not self.initialized:
|
135 |
+
parser = argparse.ArgumentParser(
|
136 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
137 |
+
parser = self.initialize(parser)
|
138 |
+
|
139 |
+
self.parser = parser
|
140 |
+
|
141 |
+
return parser.parse_args()
|
142 |
+
|
143 |
+
def print_options(self, opt):
|
144 |
+
message = ''
|
145 |
+
message += '----------------- Options ---------------\n'
|
146 |
+
for k, v in sorted(vars(opt).items()):
|
147 |
+
comment = ''
|
148 |
+
default = self.parser.get_default(k)
|
149 |
+
if v != default:
|
150 |
+
comment = '\t[default: %s]' % str(default)
|
151 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
152 |
+
message += '----------------- End -------------------'
|
153 |
+
print(message)
|
154 |
+
|
155 |
+
def parse(self):
|
156 |
+
opt = self.gather_options()
|
157 |
+
return opt
|
158 |
+
|
159 |
+
def parse_to_dict(self):
|
160 |
+
opt = self.gather_options()
|
161 |
+
return opt.__dict__
|
PIFu/lib/renderer/__init__.py
ADDED
File without changes
|
PIFu/lib/renderer/camera.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from .glm import ortho
|
5 |
+
|
6 |
+
|
7 |
+
class Camera:
|
8 |
+
def __init__(self, width=1600, height=1200):
|
9 |
+
# Focal Length
|
10 |
+
# equivalent 50mm
|
11 |
+
focal = np.sqrt(width * width + height * height)
|
12 |
+
self.focal_x = focal
|
13 |
+
self.focal_y = focal
|
14 |
+
# Principal Point Offset
|
15 |
+
self.principal_x = width / 2
|
16 |
+
self.principal_y = height / 2
|
17 |
+
# Axis Skew
|
18 |
+
self.skew = 0
|
19 |
+
# Image Size
|
20 |
+
self.width = width
|
21 |
+
self.height = height
|
22 |
+
|
23 |
+
self.near = 1
|
24 |
+
self.far = 10
|
25 |
+
|
26 |
+
# Camera Center
|
27 |
+
self.center = np.array([0, 0, 1.6])
|
28 |
+
self.direction = np.array([0, 0, -1])
|
29 |
+
self.right = np.array([1, 0, 0])
|
30 |
+
self.up = np.array([0, 1, 0])
|
31 |
+
|
32 |
+
self.ortho_ratio = None
|
33 |
+
|
34 |
+
def sanity_check(self):
|
35 |
+
self.center = self.center.reshape([-1])
|
36 |
+
self.direction = self.direction.reshape([-1])
|
37 |
+
self.right = self.right.reshape([-1])
|
38 |
+
self.up = self.up.reshape([-1])
|
39 |
+
|
40 |
+
assert len(self.center) == 3
|
41 |
+
assert len(self.direction) == 3
|
42 |
+
assert len(self.right) == 3
|
43 |
+
assert len(self.up) == 3
|
44 |
+
|
45 |
+
@staticmethod
|
46 |
+
def normalize_vector(v):
|
47 |
+
v_norm = np.linalg.norm(v)
|
48 |
+
return v if v_norm == 0 else v / v_norm
|
49 |
+
|
50 |
+
def get_real_z_value(self, z):
|
51 |
+
z_near = self.near
|
52 |
+
z_far = self.far
|
53 |
+
z_n = 2.0 * z - 1.0
|
54 |
+
z_e = 2.0 * z_near * z_far / (z_far + z_near - z_n * (z_far - z_near))
|
55 |
+
return z_e
|
56 |
+
|
57 |
+
def get_rotation_matrix(self):
|
58 |
+
rot_mat = np.eye(3)
|
59 |
+
s = self.right
|
60 |
+
s = self.normalize_vector(s)
|
61 |
+
rot_mat[0, :] = s
|
62 |
+
u = self.up
|
63 |
+
u = self.normalize_vector(u)
|
64 |
+
rot_mat[1, :] = -u
|
65 |
+
rot_mat[2, :] = self.normalize_vector(self.direction)
|
66 |
+
|
67 |
+
return rot_mat
|
68 |
+
|
69 |
+
def get_translation_vector(self):
|
70 |
+
rot_mat = self.get_rotation_matrix()
|
71 |
+
trans = -np.dot(rot_mat, self.center)
|
72 |
+
return trans
|
73 |
+
|
74 |
+
def get_intrinsic_matrix(self):
|
75 |
+
int_mat = np.eye(3)
|
76 |
+
|
77 |
+
int_mat[0, 0] = self.focal_x
|
78 |
+
int_mat[1, 1] = self.focal_y
|
79 |
+
int_mat[0, 1] = self.skew
|
80 |
+
int_mat[0, 2] = self.principal_x
|
81 |
+
int_mat[1, 2] = self.principal_y
|
82 |
+
|
83 |
+
return int_mat
|
84 |
+
|
85 |
+
def get_projection_matrix(self):
|
86 |
+
ext_mat = self.get_extrinsic_matrix()
|
87 |
+
int_mat = self.get_intrinsic_matrix()
|
88 |
+
|
89 |
+
return np.matmul(int_mat, ext_mat)
|
90 |
+
|
91 |
+
def get_extrinsic_matrix(self):
|
92 |
+
rot_mat = self.get_rotation_matrix()
|
93 |
+
int_mat = self.get_intrinsic_matrix()
|
94 |
+
trans = self.get_translation_vector()
|
95 |
+
|
96 |
+
extrinsic = np.eye(4)
|
97 |
+
extrinsic[:3, :3] = rot_mat
|
98 |
+
extrinsic[:3, 3] = trans
|
99 |
+
|
100 |
+
return extrinsic[:3, :]
|
101 |
+
|
102 |
+
def set_rotation_matrix(self, rot_mat):
|
103 |
+
self.direction = rot_mat[2, :]
|
104 |
+
self.up = -rot_mat[1, :]
|
105 |
+
self.right = rot_mat[0, :]
|
106 |
+
|
107 |
+
def set_intrinsic_matrix(self, int_mat):
|
108 |
+
self.focal_x = int_mat[0, 0]
|
109 |
+
self.focal_y = int_mat[1, 1]
|
110 |
+
self.skew = int_mat[0, 1]
|
111 |
+
self.principal_x = int_mat[0, 2]
|
112 |
+
self.principal_y = int_mat[1, 2]
|
113 |
+
|
114 |
+
def set_projection_matrix(self, proj_mat):
|
115 |
+
res = cv2.decomposeProjectionMatrix(proj_mat)
|
116 |
+
int_mat, rot_mat, camera_center_homo = res[0], res[1], res[2]
|
117 |
+
camera_center = camera_center_homo[0:3] / camera_center_homo[3]
|
118 |
+
camera_center = camera_center.reshape(-1)
|
119 |
+
int_mat = int_mat / int_mat[2][2]
|
120 |
+
|
121 |
+
self.set_intrinsic_matrix(int_mat)
|
122 |
+
self.set_rotation_matrix(rot_mat)
|
123 |
+
self.center = camera_center
|
124 |
+
|
125 |
+
self.sanity_check()
|
126 |
+
|
127 |
+
def get_gl_matrix(self):
|
128 |
+
z_near = self.near
|
129 |
+
z_far = self.far
|
130 |
+
rot_mat = self.get_rotation_matrix()
|
131 |
+
int_mat = self.get_intrinsic_matrix()
|
132 |
+
trans = self.get_translation_vector()
|
133 |
+
|
134 |
+
extrinsic = np.eye(4)
|
135 |
+
extrinsic[:3, :3] = rot_mat
|
136 |
+
extrinsic[:3, 3] = trans
|
137 |
+
axis_adj = np.eye(4)
|
138 |
+
axis_adj[2, 2] = -1
|
139 |
+
axis_adj[1, 1] = -1
|
140 |
+
model_view = np.matmul(axis_adj, extrinsic)
|
141 |
+
|
142 |
+
projective = np.zeros([4, 4])
|
143 |
+
projective[:2, :2] = int_mat[:2, :2]
|
144 |
+
projective[:2, 2:3] = -int_mat[:2, 2:3]
|
145 |
+
projective[3, 2] = -1
|
146 |
+
projective[2, 2] = (z_near + z_far)
|
147 |
+
projective[2, 3] = (z_near * z_far)
|
148 |
+
|
149 |
+
if self.ortho_ratio is None:
|
150 |
+
ndc = ortho(0, self.width, 0, self.height, z_near, z_far)
|
151 |
+
perspective = np.matmul(ndc, projective)
|
152 |
+
else:
|
153 |
+
perspective = ortho(-self.width * self.ortho_ratio / 2, self.width * self.ortho_ratio / 2,
|
154 |
+
-self.height * self.ortho_ratio / 2, self.height * self.ortho_ratio / 2,
|
155 |
+
z_near, z_far)
|
156 |
+
|
157 |
+
return perspective, model_view
|
158 |
+
|
159 |
+
|
160 |
+
def KRT_from_P(proj_mat, normalize_K=True):
|
161 |
+
res = cv2.decomposeProjectionMatrix(proj_mat)
|
162 |
+
K, Rot, camera_center_homog = res[0], res[1], res[2]
|
163 |
+
camera_center = camera_center_homog[0:3] / camera_center_homog[3]
|
164 |
+
trans = -Rot.dot(camera_center)
|
165 |
+
if normalize_K:
|
166 |
+
K = K / K[2][2]
|
167 |
+
return K, Rot, trans
|
168 |
+
|
169 |
+
|
170 |
+
def MVP_from_P(proj_mat, width, height, near=0.1, far=10000):
|
171 |
+
'''
|
172 |
+
Convert OpenCV camera calibration matrix to OpenGL projection and model view matrix
|
173 |
+
:param proj_mat: OpenCV camera projeciton matrix
|
174 |
+
:param width: Image width
|
175 |
+
:param height: Image height
|
176 |
+
:param near: Z near value
|
177 |
+
:param far: Z far value
|
178 |
+
:return: OpenGL projection matrix and model view matrix
|
179 |
+
'''
|
180 |
+
res = cv2.decomposeProjectionMatrix(proj_mat)
|
181 |
+
K, Rot, camera_center_homog = res[0], res[1], res[2]
|
182 |
+
camera_center = camera_center_homog[0:3] / camera_center_homog[3]
|
183 |
+
trans = -Rot.dot(camera_center)
|
184 |
+
K = K / K[2][2]
|
185 |
+
|
186 |
+
extrinsic = np.eye(4)
|
187 |
+
extrinsic[:3, :3] = Rot
|
188 |
+
extrinsic[:3, 3:4] = trans
|
189 |
+
axis_adj = np.eye(4)
|
190 |
+
axis_adj[2, 2] = -1
|
191 |
+
axis_adj[1, 1] = -1
|
192 |
+
model_view = np.matmul(axis_adj, extrinsic)
|
193 |
+
|
194 |
+
zFar = far
|
195 |
+
zNear = near
|
196 |
+
projective = np.zeros([4, 4])
|
197 |
+
projective[:2, :2] = K[:2, :2]
|
198 |
+
projective[:2, 2:3] = -K[:2, 2:3]
|
199 |
+
projective[3, 2] = -1
|
200 |
+
projective[2, 2] = (zNear + zFar)
|
201 |
+
projective[2, 3] = (zNear * zFar)
|
202 |
+
|
203 |
+
ndc = ortho(0, width, 0, height, zNear, zFar)
|
204 |
+
|
205 |
+
perspective = np.matmul(ndc, projective)
|
206 |
+
|
207 |
+
return perspective, model_view
|
PIFu/lib/renderer/gl/__init__.py
ADDED
File without changes
|
PIFu/lib/renderer/gl/cam_render.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .render import Render
|
2 |
+
|
3 |
+
GLUT = None
|
4 |
+
|
5 |
+
class CamRender(Render):
|
6 |
+
def __init__(self, width=1600, height=1200, name='Cam Renderer',
|
7 |
+
program_files=['simple.fs', 'simple.vs'], color_size=1, ms_rate=1, egl=False):
|
8 |
+
Render.__init__(self, width, height, name, program_files, color_size, ms_rate=ms_rate, egl=egl)
|
9 |
+
self.camera = None
|
10 |
+
|
11 |
+
if not egl:
|
12 |
+
global GLUT
|
13 |
+
import OpenGL.GLUT as GLUT
|
14 |
+
GLUT.glutDisplayFunc(self.display)
|
15 |
+
GLUT.glutKeyboardFunc(self.keyboard)
|
16 |
+
|
17 |
+
def set_camera(self, camera):
|
18 |
+
self.camera = camera
|
19 |
+
self.projection_matrix, self.model_view_matrix = camera.get_gl_matrix()
|
20 |
+
|
21 |
+
def keyboard(self, key, x, y):
|
22 |
+
# up
|
23 |
+
eps = 1
|
24 |
+
# print(key)
|
25 |
+
if key == b'w':
|
26 |
+
self.camera.center += eps * self.camera.direction
|
27 |
+
elif key == b's':
|
28 |
+
self.camera.center -= eps * self.camera.direction
|
29 |
+
if key == b'a':
|
30 |
+
self.camera.center -= eps * self.camera.right
|
31 |
+
elif key == b'd':
|
32 |
+
self.camera.center += eps * self.camera.right
|
33 |
+
if key == b' ':
|
34 |
+
self.camera.center += eps * self.camera.up
|
35 |
+
elif key == b'x':
|
36 |
+
self.camera.center -= eps * self.camera.up
|
37 |
+
elif key == b'i':
|
38 |
+
self.camera.near += 0.1 * eps
|
39 |
+
self.camera.far += 0.1 * eps
|
40 |
+
elif key == b'o':
|
41 |
+
self.camera.near -= 0.1 * eps
|
42 |
+
self.camera.far -= 0.1 * eps
|
43 |
+
|
44 |
+
self.projection_matrix, self.model_view_matrix = self.camera.get_gl_matrix()
|
45 |
+
|
46 |
+
def show(self):
|
47 |
+
if GLUT is not None:
|
48 |
+
GLUT.glutMainLoop()
|
PIFu/lib/renderer/gl/data/prt.fs
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#version 330
|
2 |
+
|
3 |
+
uniform vec3 SHCoeffs[9];
|
4 |
+
uniform uint analytic;
|
5 |
+
|
6 |
+
uniform uint hasNormalMap;
|
7 |
+
uniform uint hasAlbedoMap;
|
8 |
+
|
9 |
+
uniform sampler2D AlbedoMap;
|
10 |
+
uniform sampler2D NormalMap;
|
11 |
+
|
12 |
+
in VertexData {
|
13 |
+
vec3 Position;
|
14 |
+
vec3 Depth;
|
15 |
+
vec3 ModelNormal;
|
16 |
+
vec2 Texcoord;
|
17 |
+
vec3 Tangent;
|
18 |
+
vec3 Bitangent;
|
19 |
+
vec3 PRT1;
|
20 |
+
vec3 PRT2;
|
21 |
+
vec3 PRT3;
|
22 |
+
} VertexIn;
|
23 |
+
|
24 |
+
layout (location = 0) out vec4 FragColor;
|
25 |
+
layout (location = 1) out vec4 FragNormal;
|
26 |
+
layout (location = 2) out vec4 FragPosition;
|
27 |
+
layout (location = 3) out vec4 FragAlbedo;
|
28 |
+
layout (location = 4) out vec4 FragShading;
|
29 |
+
layout (location = 5) out vec4 FragPRT1;
|
30 |
+
layout (location = 6) out vec4 FragPRT2;
|
31 |
+
layout (location = 7) out vec4 FragPRT3;
|
32 |
+
|
33 |
+
vec4 gammaCorrection(vec4 vec, float g)
|
34 |
+
{
|
35 |
+
return vec4(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g), vec.w);
|
36 |
+
}
|
37 |
+
|
38 |
+
vec3 gammaCorrection(vec3 vec, float g)
|
39 |
+
{
|
40 |
+
return vec3(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g));
|
41 |
+
}
|
42 |
+
|
43 |
+
void evaluateH(vec3 n, out float H[9])
|
44 |
+
{
|
45 |
+
float c1 = 0.429043, c2 = 0.511664,
|
46 |
+
c3 = 0.743125, c4 = 0.886227, c5 = 0.247708;
|
47 |
+
|
48 |
+
H[0] = c4;
|
49 |
+
H[1] = 2.0 * c2 * n[1];
|
50 |
+
H[2] = 2.0 * c2 * n[2];
|
51 |
+
H[3] = 2.0 * c2 * n[0];
|
52 |
+
H[4] = 2.0 * c1 * n[0] * n[1];
|
53 |
+
H[5] = 2.0 * c1 * n[1] * n[2];
|
54 |
+
H[6] = c3 * n[2] * n[2] - c5;
|
55 |
+
H[7] = 2.0 * c1 * n[2] * n[0];
|
56 |
+
H[8] = c1 * (n[0] * n[0] - n[1] * n[1]);
|
57 |
+
}
|
58 |
+
|
59 |
+
vec3 evaluateLightingModel(vec3 normal)
|
60 |
+
{
|
61 |
+
float H[9];
|
62 |
+
evaluateH(normal, H);
|
63 |
+
vec3 res = vec3(0.0);
|
64 |
+
for (int i = 0; i < 9; i++) {
|
65 |
+
res += H[i] * SHCoeffs[i];
|
66 |
+
}
|
67 |
+
return res;
|
68 |
+
}
|
69 |
+
|
70 |
+
// nC: coarse geometry normal, nH: fine normal from normal map
|
71 |
+
vec3 evaluateLightingModelHybrid(vec3 nC, vec3 nH, mat3 prt)
|
72 |
+
{
|
73 |
+
float HC[9], HH[9];
|
74 |
+
evaluateH(nC, HC);
|
75 |
+
evaluateH(nH, HH);
|
76 |
+
|
77 |
+
vec3 res = vec3(0.0);
|
78 |
+
vec3 shadow = vec3(0.0);
|
79 |
+
vec3 unshadow = vec3(0.0);
|
80 |
+
for(int i = 0; i < 3; ++i){
|
81 |
+
for(int j = 0; j < 3; ++j){
|
82 |
+
int id = i*3+j;
|
83 |
+
res += HH[id]* SHCoeffs[id];
|
84 |
+
shadow += prt[i][j] * SHCoeffs[id];
|
85 |
+
unshadow += HC[id] * SHCoeffs[id];
|
86 |
+
}
|
87 |
+
}
|
88 |
+
vec3 ratio = clamp(shadow/unshadow,0.0,1.0);
|
89 |
+
res = ratio * res;
|
90 |
+
|
91 |
+
return res;
|
92 |
+
}
|
93 |
+
|
94 |
+
vec3 evaluateLightingModelPRT(mat3 prt)
|
95 |
+
{
|
96 |
+
vec3 res = vec3(0.0);
|
97 |
+
for(int i = 0; i < 3; ++i){
|
98 |
+
for(int j = 0; j < 3; ++j){
|
99 |
+
res += prt[i][j] * SHCoeffs[i*3+j];
|
100 |
+
}
|
101 |
+
}
|
102 |
+
|
103 |
+
return res;
|
104 |
+
}
|
105 |
+
|
106 |
+
void main()
|
107 |
+
{
|
108 |
+
vec2 uv = VertexIn.Texcoord;
|
109 |
+
vec3 nC = normalize(VertexIn.ModelNormal);
|
110 |
+
vec3 nml = nC;
|
111 |
+
mat3 prt = mat3(VertexIn.PRT1, VertexIn.PRT2, VertexIn.PRT3);
|
112 |
+
|
113 |
+
if(hasAlbedoMap == uint(0))
|
114 |
+
FragAlbedo = vec4(1.0);
|
115 |
+
else
|
116 |
+
FragAlbedo = texture(AlbedoMap, uv);//gammaCorrection(texture(AlbedoMap, uv), 1.0/2.2);
|
117 |
+
|
118 |
+
if(hasNormalMap == uint(0))
|
119 |
+
{
|
120 |
+
if(analytic == uint(0))
|
121 |
+
FragShading = vec4(evaluateLightingModelPRT(prt), 1.0f);
|
122 |
+
else
|
123 |
+
FragShading = vec4(evaluateLightingModel(nC), 1.0f);
|
124 |
+
}
|
125 |
+
else
|
126 |
+
{
|
127 |
+
vec3 n_tan = normalize(texture(NormalMap, uv).rgb*2.0-vec3(1.0));
|
128 |
+
|
129 |
+
mat3 TBN = mat3(normalize(VertexIn.Tangent),normalize(VertexIn.Bitangent),nC);
|
130 |
+
vec3 nH = normalize(TBN * n_tan);
|
131 |
+
|
132 |
+
if(analytic == uint(0))
|
133 |
+
FragShading = vec4(evaluateLightingModelHybrid(nC,nH,prt),1.0f);
|
134 |
+
else
|
135 |
+
FragShading = vec4(evaluateLightingModel(nH), 1.0f);
|
136 |
+
|
137 |
+
nml = nH;
|
138 |
+
}
|
139 |
+
|
140 |
+
FragShading = gammaCorrection(FragShading, 2.2);
|
141 |
+
FragColor = clamp(FragAlbedo * FragShading, 0.0, 1.0);
|
142 |
+
FragNormal = vec4(0.5*(nml+vec3(1.0)), 1.0);
|
143 |
+
FragPosition = vec4(VertexIn.Position,VertexIn.Depth.x);
|
144 |
+
FragShading = vec4(clamp(0.5*FragShading.xyz, 0.0, 1.0),1.0);
|
145 |
+
// FragColor = gammaCorrection(clamp(FragAlbedo * FragShading, 0.0, 1.0),2.2);
|
146 |
+
// FragNormal = vec4(0.5*(nml+vec3(1.0)), 1.0);
|
147 |
+
// FragPosition = vec4(VertexIn.Position,VertexIn.Depth.x);
|
148 |
+
// FragShading = vec4(gammaCorrection(clamp(0.5*FragShading.xyz, 0.0, 1.0),2.2),1.0);
|
149 |
+
// FragAlbedo = gammaCorrection(FragAlbedo,2.2);
|
150 |
+
FragPRT1 = vec4(VertexIn.PRT1,1.0);
|
151 |
+
FragPRT2 = vec4(VertexIn.PRT2,1.0);
|
152 |
+
FragPRT3 = vec4(VertexIn.PRT3,1.0);
|
153 |
+
}
|
PIFu/lib/renderer/gl/data/prt.vs
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#version 330
|
2 |
+
|
3 |
+
layout (location = 0) in vec3 a_Position;
|
4 |
+
layout (location = 1) in vec3 a_Normal;
|
5 |
+
layout (location = 2) in vec2 a_TextureCoord;
|
6 |
+
layout (location = 3) in vec3 a_Tangent;
|
7 |
+
layout (location = 4) in vec3 a_Bitangent;
|
8 |
+
layout (location = 5) in vec3 a_PRT1;
|
9 |
+
layout (location = 6) in vec3 a_PRT2;
|
10 |
+
layout (location = 7) in vec3 a_PRT3;
|
11 |
+
|
12 |
+
out VertexData {
|
13 |
+
vec3 Position;
|
14 |
+
vec3 Depth;
|
15 |
+
vec3 ModelNormal;
|
16 |
+
vec2 Texcoord;
|
17 |
+
vec3 Tangent;
|
18 |
+
vec3 Bitangent;
|
19 |
+
vec3 PRT1;
|
20 |
+
vec3 PRT2;
|
21 |
+
vec3 PRT3;
|
22 |
+
} VertexOut;
|
23 |
+
|
24 |
+
uniform mat3 RotMat;
|
25 |
+
uniform mat4 NormMat;
|
26 |
+
uniform mat4 ModelMat;
|
27 |
+
uniform mat4 PerspMat;
|
28 |
+
|
29 |
+
float s_c3 = 0.94617469575; // (3*sqrt(5))/(4*sqrt(pi))
|
30 |
+
float s_c4 = -0.31539156525;// (-sqrt(5))/(4*sqrt(pi))
|
31 |
+
float s_c5 = 0.54627421529; // (sqrt(15))/(4*sqrt(pi))
|
32 |
+
|
33 |
+
float s_c_scale = 1.0/0.91529123286551084;
|
34 |
+
float s_c_scale_inv = 0.91529123286551084;
|
35 |
+
|
36 |
+
float s_rc2 = 1.5853309190550713*s_c_scale;
|
37 |
+
float s_c4_div_c3 = s_c4/s_c3;
|
38 |
+
float s_c4_div_c3_x2 = (s_c4/s_c3)*2.0;
|
39 |
+
|
40 |
+
float s_scale_dst2 = s_c3 * s_c_scale_inv;
|
41 |
+
float s_scale_dst4 = s_c5 * s_c_scale_inv;
|
42 |
+
|
43 |
+
void OptRotateBand0(float x[1], mat3 R, out float dst[1])
|
44 |
+
{
|
45 |
+
dst[0] = x[0];
|
46 |
+
}
|
47 |
+
|
48 |
+
// 9 multiplies
|
49 |
+
void OptRotateBand1(float x[3], mat3 R, out float dst[3])
|
50 |
+
{
|
51 |
+
// derived from SlowRotateBand1
|
52 |
+
dst[0] = ( R[1][1])*x[0] + (-R[1][2])*x[1] + ( R[1][0])*x[2];
|
53 |
+
dst[1] = (-R[2][1])*x[0] + ( R[2][2])*x[1] + (-R[2][0])*x[2];
|
54 |
+
dst[2] = ( R[0][1])*x[0] + (-R[0][2])*x[1] + ( R[0][0])*x[2];
|
55 |
+
}
|
56 |
+
|
57 |
+
// 48 multiplies
|
58 |
+
void OptRotateBand2(float x[5], mat3 R, out float dst[5])
|
59 |
+
{
|
60 |
+
// Sparse matrix multiply
|
61 |
+
float sh0 = x[3] + x[4] + x[4] - x[1];
|
62 |
+
float sh1 = x[0] + s_rc2*x[2] + x[3] + x[4];
|
63 |
+
float sh2 = x[0];
|
64 |
+
float sh3 = -x[3];
|
65 |
+
float sh4 = -x[1];
|
66 |
+
|
67 |
+
// Rotations. R0 and R1 just use the raw matrix columns
|
68 |
+
float r2x = R[0][0] + R[0][1];
|
69 |
+
float r2y = R[1][0] + R[1][1];
|
70 |
+
float r2z = R[2][0] + R[2][1];
|
71 |
+
|
72 |
+
float r3x = R[0][0] + R[0][2];
|
73 |
+
float r3y = R[1][0] + R[1][2];
|
74 |
+
float r3z = R[2][0] + R[2][2];
|
75 |
+
|
76 |
+
float r4x = R[0][1] + R[0][2];
|
77 |
+
float r4y = R[1][1] + R[1][2];
|
78 |
+
float r4z = R[2][1] + R[2][2];
|
79 |
+
|
80 |
+
// dense matrix multiplication one column at a time
|
81 |
+
|
82 |
+
// column 0
|
83 |
+
float sh0_x = sh0 * R[0][0];
|
84 |
+
float sh0_y = sh0 * R[1][0];
|
85 |
+
float d0 = sh0_x * R[1][0];
|
86 |
+
float d1 = sh0_y * R[2][0];
|
87 |
+
float d2 = sh0 * (R[2][0] * R[2][0] + s_c4_div_c3);
|
88 |
+
float d3 = sh0_x * R[2][0];
|
89 |
+
float d4 = sh0_x * R[0][0] - sh0_y * R[1][0];
|
90 |
+
|
91 |
+
// column 1
|
92 |
+
float sh1_x = sh1 * R[0][2];
|
93 |
+
float sh1_y = sh1 * R[1][2];
|
94 |
+
d0 += sh1_x * R[1][2];
|
95 |
+
d1 += sh1_y * R[2][2];
|
96 |
+
d2 += sh1 * (R[2][2] * R[2][2] + s_c4_div_c3);
|
97 |
+
d3 += sh1_x * R[2][2];
|
98 |
+
d4 += sh1_x * R[0][2] - sh1_y * R[1][2];
|
99 |
+
|
100 |
+
// column 2
|
101 |
+
float sh2_x = sh2 * r2x;
|
102 |
+
float sh2_y = sh2 * r2y;
|
103 |
+
d0 += sh2_x * r2y;
|
104 |
+
d1 += sh2_y * r2z;
|
105 |
+
d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2);
|
106 |
+
d3 += sh2_x * r2z;
|
107 |
+
d4 += sh2_x * r2x - sh2_y * r2y;
|
108 |
+
|
109 |
+
// column 3
|
110 |
+
float sh3_x = sh3 * r3x;
|
111 |
+
float sh3_y = sh3 * r3y;
|
112 |
+
d0 += sh3_x * r3y;
|
113 |
+
d1 += sh3_y * r3z;
|
114 |
+
d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2);
|
115 |
+
d3 += sh3_x * r3z;
|
116 |
+
d4 += sh3_x * r3x - sh3_y * r3y;
|
117 |
+
|
118 |
+
// column 4
|
119 |
+
float sh4_x = sh4 * r4x;
|
120 |
+
float sh4_y = sh4 * r4y;
|
121 |
+
d0 += sh4_x * r4y;
|
122 |
+
d1 += sh4_y * r4z;
|
123 |
+
d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2);
|
124 |
+
d3 += sh4_x * r4z;
|
125 |
+
d4 += sh4_x * r4x - sh4_y * r4y;
|
126 |
+
|
127 |
+
// extra multipliers
|
128 |
+
dst[0] = d0;
|
129 |
+
dst[1] = -d1;
|
130 |
+
dst[2] = d2 * s_scale_dst2;
|
131 |
+
dst[3] = -d3;
|
132 |
+
dst[4] = d4 * s_scale_dst4;
|
133 |
+
}
|
134 |
+
|
135 |
+
void main()
|
136 |
+
{
|
137 |
+
// normalization
|
138 |
+
vec3 pos = (NormMat * vec4(a_Position,1.0)).xyz;
|
139 |
+
|
140 |
+
mat3 R = mat3(ModelMat) * RotMat;
|
141 |
+
VertexOut.ModelNormal = (R * a_Normal);
|
142 |
+
VertexOut.Position = R * pos;
|
143 |
+
VertexOut.Texcoord = a_TextureCoord;
|
144 |
+
VertexOut.Tangent = (R * a_Tangent);
|
145 |
+
VertexOut.Bitangent = (R * a_Bitangent);
|
146 |
+
float PRT0, PRT1[3], PRT2[5];
|
147 |
+
PRT0 = a_PRT1[0];
|
148 |
+
PRT1[0] = a_PRT1[1];
|
149 |
+
PRT1[1] = a_PRT1[2];
|
150 |
+
PRT1[2] = a_PRT2[0];
|
151 |
+
PRT2[0] = a_PRT2[1];
|
152 |
+
PRT2[1] = a_PRT2[2];
|
153 |
+
PRT2[2] = a_PRT3[0];
|
154 |
+
PRT2[3] = a_PRT3[1];
|
155 |
+
PRT2[4] = a_PRT3[2];
|
156 |
+
|
157 |
+
OptRotateBand1(PRT1, R, PRT1);
|
158 |
+
OptRotateBand2(PRT2, R, PRT2);
|
159 |
+
|
160 |
+
VertexOut.PRT1 = vec3(PRT0,PRT1[0],PRT1[1]);
|
161 |
+
VertexOut.PRT2 = vec3(PRT1[2],PRT2[0],PRT2[1]);
|
162 |
+
VertexOut.PRT3 = vec3(PRT2[2],PRT2[3],PRT2[4]);
|
163 |
+
|
164 |
+
gl_Position = PerspMat * ModelMat * vec4(RotMat * pos, 1.0);
|
165 |
+
|
166 |
+
VertexOut.Depth = vec3(gl_Position.z / gl_Position.w);
|
167 |
+
}
|
PIFu/lib/renderer/gl/data/prt_uv.fs
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#version 330
|
2 |
+
|
3 |
+
uniform vec3 SHCoeffs[9];
|
4 |
+
uniform uint analytic;
|
5 |
+
|
6 |
+
uniform uint hasNormalMap;
|
7 |
+
uniform uint hasAlbedoMap;
|
8 |
+
|
9 |
+
uniform sampler2D AlbedoMap;
|
10 |
+
uniform sampler2D NormalMap;
|
11 |
+
|
12 |
+
in VertexData {
|
13 |
+
vec3 Position;
|
14 |
+
vec3 ModelNormal;
|
15 |
+
vec3 CameraNormal;
|
16 |
+
vec2 Texcoord;
|
17 |
+
vec3 Tangent;
|
18 |
+
vec3 Bitangent;
|
19 |
+
vec3 PRT1;
|
20 |
+
vec3 PRT2;
|
21 |
+
vec3 PRT3;
|
22 |
+
} VertexIn;
|
23 |
+
|
24 |
+
layout (location = 0) out vec4 FragColor;
|
25 |
+
layout (location = 1) out vec4 FragPosition;
|
26 |
+
layout (location = 2) out vec4 FragNormal;
|
27 |
+
|
28 |
+
vec4 gammaCorrection(vec4 vec, float g)
|
29 |
+
{
|
30 |
+
return vec4(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g), vec.w);
|
31 |
+
}
|
32 |
+
|
33 |
+
vec3 gammaCorrection(vec3 vec, float g)
|
34 |
+
{
|
35 |
+
return vec3(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g));
|
36 |
+
}
|
37 |
+
|
38 |
+
void evaluateH(vec3 n, out float H[9])
|
39 |
+
{
|
40 |
+
float c1 = 0.429043, c2 = 0.511664,
|
41 |
+
c3 = 0.743125, c4 = 0.886227, c5 = 0.247708;
|
42 |
+
|
43 |
+
H[0] = c4;
|
44 |
+
H[1] = 2.0 * c2 * n[1];
|
45 |
+
H[2] = 2.0 * c2 * n[2];
|
46 |
+
H[3] = 2.0 * c2 * n[0];
|
47 |
+
H[4] = 2.0 * c1 * n[0] * n[1];
|
48 |
+
H[5] = 2.0 * c1 * n[1] * n[2];
|
49 |
+
H[6] = c3 * n[2] * n[2] - c5;
|
50 |
+
H[7] = 2.0 * c1 * n[2] * n[0];
|
51 |
+
H[8] = c1 * (n[0] * n[0] - n[1] * n[1]);
|
52 |
+
}
|
53 |
+
|
54 |
+
vec3 evaluateLightingModel(vec3 normal)
|
55 |
+
{
|
56 |
+
float H[9];
|
57 |
+
evaluateH(normal, H);
|
58 |
+
vec3 res = vec3(0.0);
|
59 |
+
for (int i = 0; i < 9; i++) {
|
60 |
+
res += H[i] * SHCoeffs[i];
|
61 |
+
}
|
62 |
+
return res;
|
63 |
+
}
|
64 |
+
|
65 |
+
// nC: coarse geometry normal, nH: fine normal from normal map
|
66 |
+
vec3 evaluateLightingModelHybrid(vec3 nC, vec3 nH, mat3 prt)
|
67 |
+
{
|
68 |
+
float HC[9], HH[9];
|
69 |
+
evaluateH(nC, HC);
|
70 |
+
evaluateH(nH, HH);
|
71 |
+
|
72 |
+
vec3 res = vec3(0.0);
|
73 |
+
vec3 shadow = vec3(0.0);
|
74 |
+
vec3 unshadow = vec3(0.0);
|
75 |
+
for(int i = 0; i < 3; ++i){
|
76 |
+
for(int j = 0; j < 3; ++j){
|
77 |
+
int id = i*3+j;
|
78 |
+
res += HH[id]* SHCoeffs[id];
|
79 |
+
shadow += prt[i][j] * SHCoeffs[id];
|
80 |
+
unshadow += HC[id] * SHCoeffs[id];
|
81 |
+
}
|
82 |
+
}
|
83 |
+
vec3 ratio = clamp(shadow/unshadow,0.0,1.0);
|
84 |
+
res = ratio * res;
|
85 |
+
|
86 |
+
return res;
|
87 |
+
}
|
88 |
+
|
89 |
+
vec3 evaluateLightingModelPRT(mat3 prt)
|
90 |
+
{
|
91 |
+
vec3 res = vec3(0.0);
|
92 |
+
for(int i = 0; i < 3; ++i){
|
93 |
+
for(int j = 0; j < 3; ++j){
|
94 |
+
res += prt[i][j] * SHCoeffs[i*3+j];
|
95 |
+
}
|
96 |
+
}
|
97 |
+
|
98 |
+
return res;
|
99 |
+
}
|
100 |
+
|
101 |
+
void main()
|
102 |
+
{
|
103 |
+
vec2 uv = VertexIn.Texcoord;
|
104 |
+
vec3 nM = normalize(VertexIn.ModelNormal);
|
105 |
+
vec3 nC = normalize(VertexIn.CameraNormal);
|
106 |
+
vec3 nml = nC;
|
107 |
+
mat3 prt = mat3(VertexIn.PRT1, VertexIn.PRT2, VertexIn.PRT3);
|
108 |
+
|
109 |
+
vec4 albedo, shading;
|
110 |
+
if(hasAlbedoMap == uint(0))
|
111 |
+
albedo = vec4(1.0);
|
112 |
+
else
|
113 |
+
albedo = texture(AlbedoMap, uv);//gammaCorrection(texture(AlbedoMap, uv), 1.0/2.2);
|
114 |
+
|
115 |
+
if(hasNormalMap == uint(0))
|
116 |
+
{
|
117 |
+
if(analytic == uint(0))
|
118 |
+
shading = vec4(evaluateLightingModelPRT(prt), 1.0f);
|
119 |
+
else
|
120 |
+
shading = vec4(evaluateLightingModel(nC), 1.0f);
|
121 |
+
}
|
122 |
+
else
|
123 |
+
{
|
124 |
+
vec3 n_tan = normalize(texture(NormalMap, uv).rgb*2.0-vec3(1.0));
|
125 |
+
|
126 |
+
mat3 TBN = mat3(normalize(VertexIn.Tangent),normalize(VertexIn.Bitangent),nC);
|
127 |
+
vec3 nH = normalize(TBN * n_tan);
|
128 |
+
|
129 |
+
if(analytic == uint(0))
|
130 |
+
shading = vec4(evaluateLightingModelHybrid(nC,nH,prt),1.0f);
|
131 |
+
else
|
132 |
+
shading = vec4(evaluateLightingModel(nH), 1.0f);
|
133 |
+
|
134 |
+
nml = nH;
|
135 |
+
}
|
136 |
+
|
137 |
+
shading = gammaCorrection(shading, 2.2);
|
138 |
+
FragColor = clamp(albedo * shading, 0.0, 1.0);
|
139 |
+
FragPosition = vec4(VertexIn.Position,1.0);
|
140 |
+
FragNormal = vec4(0.5*(nM+vec3(1.0)),1.0);
|
141 |
+
}
|
PIFu/lib/renderer/gl/data/prt_uv.vs
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#version 330
|
2 |
+
|
3 |
+
layout (location = 0) in vec3 a_Position;
|
4 |
+
layout (location = 1) in vec3 a_Normal;
|
5 |
+
layout (location = 2) in vec2 a_TextureCoord;
|
6 |
+
layout (location = 3) in vec3 a_Tangent;
|
7 |
+
layout (location = 4) in vec3 a_Bitangent;
|
8 |
+
layout (location = 5) in vec3 a_PRT1;
|
9 |
+
layout (location = 6) in vec3 a_PRT2;
|
10 |
+
layout (location = 7) in vec3 a_PRT3;
|
11 |
+
|
12 |
+
out VertexData {
|
13 |
+
vec3 Position;
|
14 |
+
vec3 ModelNormal;
|
15 |
+
vec3 CameraNormal;
|
16 |
+
vec2 Texcoord;
|
17 |
+
vec3 Tangent;
|
18 |
+
vec3 Bitangent;
|
19 |
+
vec3 PRT1;
|
20 |
+
vec3 PRT2;
|
21 |
+
vec3 PRT3;
|
22 |
+
} VertexOut;
|
23 |
+
|
24 |
+
uniform mat3 RotMat;
|
25 |
+
uniform mat4 NormMat;
|
26 |
+
uniform mat4 ModelMat;
|
27 |
+
uniform mat4 PerspMat;
|
28 |
+
|
29 |
+
#define pi 3.1415926535897932384626433832795
|
30 |
+
|
31 |
+
float s_c3 = 0.94617469575; // (3*sqrt(5))/(4*sqrt(pi))
|
32 |
+
float s_c4 = -0.31539156525;// (-sqrt(5))/(4*sqrt(pi))
|
33 |
+
float s_c5 = 0.54627421529; // (sqrt(15))/(4*sqrt(pi))
|
34 |
+
|
35 |
+
float s_c_scale = 1.0/0.91529123286551084;
|
36 |
+
float s_c_scale_inv = 0.91529123286551084;
|
37 |
+
|
38 |
+
float s_rc2 = 1.5853309190550713*s_c_scale;
|
39 |
+
float s_c4_div_c3 = s_c4/s_c3;
|
40 |
+
float s_c4_div_c3_x2 = (s_c4/s_c3)*2.0;
|
41 |
+
|
42 |
+
float s_scale_dst2 = s_c3 * s_c_scale_inv;
|
43 |
+
float s_scale_dst4 = s_c5 * s_c_scale_inv;
|
44 |
+
|
45 |
+
void OptRotateBand0(float x[1], mat3 R, out float dst[1])
|
46 |
+
{
|
47 |
+
dst[0] = x[0];
|
48 |
+
}
|
49 |
+
|
50 |
+
// 9 multiplies
|
51 |
+
void OptRotateBand1(float x[3], mat3 R, out float dst[3])
|
52 |
+
{
|
53 |
+
// derived from SlowRotateBand1
|
54 |
+
dst[0] = ( R[1][1])*x[0] + (-R[1][2])*x[1] + ( R[1][0])*x[2];
|
55 |
+
dst[1] = (-R[2][1])*x[0] + ( R[2][2])*x[1] + (-R[2][0])*x[2];
|
56 |
+
dst[2] = ( R[0][1])*x[0] + (-R[0][2])*x[1] + ( R[0][0])*x[2];
|
57 |
+
}
|
58 |
+
|
59 |
+
// 48 multiplies
|
60 |
+
void OptRotateBand2(float x[5], mat3 R, out float dst[5])
|
61 |
+
{
|
62 |
+
// Sparse matrix multiply
|
63 |
+
float sh0 = x[3] + x[4] + x[4] - x[1];
|
64 |
+
float sh1 = x[0] + s_rc2*x[2] + x[3] + x[4];
|
65 |
+
float sh2 = x[0];
|
66 |
+
float sh3 = -x[3];
|
67 |
+
float sh4 = -x[1];
|
68 |
+
|
69 |
+
// Rotations. R0 and R1 just use the raw matrix columns
|
70 |
+
float r2x = R[0][0] + R[0][1];
|
71 |
+
float r2y = R[1][0] + R[1][1];
|
72 |
+
float r2z = R[2][0] + R[2][1];
|
73 |
+
|
74 |
+
float r3x = R[0][0] + R[0][2];
|
75 |
+
float r3y = R[1][0] + R[1][2];
|
76 |
+
float r3z = R[2][0] + R[2][2];
|
77 |
+
|
78 |
+
float r4x = R[0][1] + R[0][2];
|
79 |
+
float r4y = R[1][1] + R[1][2];
|
80 |
+
float r4z = R[2][1] + R[2][2];
|
81 |
+
|
82 |
+
// dense matrix multiplication one column at a time
|
83 |
+
|
84 |
+
// column 0
|
85 |
+
float sh0_x = sh0 * R[0][0];
|
86 |
+
float sh0_y = sh0 * R[1][0];
|
87 |
+
float d0 = sh0_x * R[1][0];
|
88 |
+
float d1 = sh0_y * R[2][0];
|
89 |
+
float d2 = sh0 * (R[2][0] * R[2][0] + s_c4_div_c3);
|
90 |
+
float d3 = sh0_x * R[2][0];
|
91 |
+
float d4 = sh0_x * R[0][0] - sh0_y * R[1][0];
|
92 |
+
|
93 |
+
// column 1
|
94 |
+
float sh1_x = sh1 * R[0][2];
|
95 |
+
float sh1_y = sh1 * R[1][2];
|
96 |
+
d0 += sh1_x * R[1][2];
|
97 |
+
d1 += sh1_y * R[2][2];
|
98 |
+
d2 += sh1 * (R[2][2] * R[2][2] + s_c4_div_c3);
|
99 |
+
d3 += sh1_x * R[2][2];
|
100 |
+
d4 += sh1_x * R[0][2] - sh1_y * R[1][2];
|
101 |
+
|
102 |
+
// column 2
|
103 |
+
float sh2_x = sh2 * r2x;
|
104 |
+
float sh2_y = sh2 * r2y;
|
105 |
+
d0 += sh2_x * r2y;
|
106 |
+
d1 += sh2_y * r2z;
|
107 |
+
d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2);
|
108 |
+
d3 += sh2_x * r2z;
|
109 |
+
d4 += sh2_x * r2x - sh2_y * r2y;
|
110 |
+
|
111 |
+
// column 3
|
112 |
+
float sh3_x = sh3 * r3x;
|
113 |
+
float sh3_y = sh3 * r3y;
|
114 |
+
d0 += sh3_x * r3y;
|
115 |
+
d1 += sh3_y * r3z;
|
116 |
+
d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2);
|
117 |
+
d3 += sh3_x * r3z;
|
118 |
+
d4 += sh3_x * r3x - sh3_y * r3y;
|
119 |
+
|
120 |
+
// column 4
|
121 |
+
float sh4_x = sh4 * r4x;
|
122 |
+
float sh4_y = sh4 * r4y;
|
123 |
+
d0 += sh4_x * r4y;
|
124 |
+
d1 += sh4_y * r4z;
|
125 |
+
d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2);
|
126 |
+
d3 += sh4_x * r4z;
|
127 |
+
d4 += sh4_x * r4x - sh4_y * r4y;
|
128 |
+
|
129 |
+
// extra multipliers
|
130 |
+
dst[0] = d0;
|
131 |
+
dst[1] = -d1;
|
132 |
+
dst[2] = d2 * s_scale_dst2;
|
133 |
+
dst[3] = -d3;
|
134 |
+
dst[4] = d4 * s_scale_dst4;
|
135 |
+
}
|
136 |
+
|
137 |
+
void main()
|
138 |
+
{
|
139 |
+
// normalization
|
140 |
+
mat3 R = mat3(ModelMat) * RotMat;
|
141 |
+
VertexOut.ModelNormal = a_Normal;
|
142 |
+
VertexOut.CameraNormal = (R * a_Normal);
|
143 |
+
VertexOut.Position = a_Position;
|
144 |
+
VertexOut.Texcoord = a_TextureCoord;
|
145 |
+
VertexOut.Tangent = (R * a_Tangent);
|
146 |
+
VertexOut.Bitangent = (R * a_Bitangent);
|
147 |
+
float PRT0, PRT1[3], PRT2[5];
|
148 |
+
PRT0 = a_PRT1[0];
|
149 |
+
PRT1[0] = a_PRT1[1];
|
150 |
+
PRT1[1] = a_PRT1[2];
|
151 |
+
PRT1[2] = a_PRT2[0];
|
152 |
+
PRT2[0] = a_PRT2[1];
|
153 |
+
PRT2[1] = a_PRT2[2];
|
154 |
+
PRT2[2] = a_PRT3[0];
|
155 |
+
PRT2[3] = a_PRT3[1];
|
156 |
+
PRT2[4] = a_PRT3[2];
|
157 |
+
|
158 |
+
OptRotateBand1(PRT1, R, PRT1);
|
159 |
+
OptRotateBand2(PRT2, R, PRT2);
|
160 |
+
|
161 |
+
VertexOut.PRT1 = vec3(PRT0,PRT1[0],PRT1[1]);
|
162 |
+
VertexOut.PRT2 = vec3(PRT1[2],PRT2[0],PRT2[1]);
|
163 |
+
VertexOut.PRT3 = vec3(PRT2[2],PRT2[3],PRT2[4]);
|
164 |
+
|
165 |
+
gl_Position = vec4(a_TextureCoord, 0.0, 1.0) - vec4(0.5, 0.5, 0, 0);
|
166 |
+
gl_Position[0] *= 2.0;
|
167 |
+
gl_Position[1] *= 2.0;
|
168 |
+
}
|
PIFu/lib/renderer/gl/data/quad.fs
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#version 330 core
|
2 |
+
out vec4 FragColor;
|
3 |
+
|
4 |
+
in vec2 TexCoord;
|
5 |
+
|
6 |
+
uniform sampler2D screenTexture;
|
7 |
+
|
8 |
+
void main()
|
9 |
+
{
|
10 |
+
FragColor = texture(screenTexture, TexCoord);
|
11 |
+
}
|
PIFu/lib/renderer/gl/data/quad.vs
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#version 330 core
|
2 |
+
layout (location = 0) in vec2 aPos;
|
3 |
+
layout (location = 1) in vec2 aTexCoord;
|
4 |
+
|
5 |
+
out vec2 TexCoord;
|
6 |
+
|
7 |
+
void main()
|
8 |
+
{
|
9 |
+
gl_Position = vec4(aPos.x, aPos.y, 0.0, 1.0);
|
10 |
+
TexCoord = aTexCoord;
|
11 |
+
}
|
PIFu/lib/renderer/gl/framework.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Mario Rosasco, 2016
|
2 |
+
# adapted from framework.cpp, Copyright (C) 2010-2012 by Jason L. McKesson
|
3 |
+
# This file is licensed under the MIT License.
|
4 |
+
#
|
5 |
+
# NB: Unlike in the framework.cpp organization, the main loop is contained
|
6 |
+
# in the tutorial files, not in this framework file. Additionally, a copy of
|
7 |
+
# this module file must exist in the same directory as the tutorial files
|
8 |
+
# to be imported properly.
|
9 |
+
|
10 |
+
import os
|
11 |
+
from OpenGL.GL import *
|
12 |
+
|
13 |
+
# Function that creates and compiles shaders according to the given type (a GL enum value) and
|
14 |
+
# shader program (a file containing a GLSL program).
|
15 |
+
def loadShader(shaderType, shaderFile):
|
16 |
+
# check if file exists, get full path name
|
17 |
+
strFilename = findFileOrThrow(shaderFile)
|
18 |
+
shaderData = None
|
19 |
+
with open(strFilename, 'r') as f:
|
20 |
+
shaderData = f.read()
|
21 |
+
|
22 |
+
shader = glCreateShader(shaderType)
|
23 |
+
glShaderSource(shader, shaderData) # note that this is a simpler function call than in C
|
24 |
+
|
25 |
+
# This shader compilation is more explicit than the one used in
|
26 |
+
# framework.cpp, which relies on a glutil wrapper function.
|
27 |
+
# This is made explicit here mainly to decrease dependence on pyOpenGL
|
28 |
+
# utilities and wrappers, which docs caution may change in future versions.
|
29 |
+
glCompileShader(shader)
|
30 |
+
|
31 |
+
status = glGetShaderiv(shader, GL_COMPILE_STATUS)
|
32 |
+
if status == GL_FALSE:
|
33 |
+
# Note that getting the error log is much simpler in Python than in C/C++
|
34 |
+
# and does not require explicit handling of the string buffer
|
35 |
+
strInfoLog = glGetShaderInfoLog(shader)
|
36 |
+
strShaderType = ""
|
37 |
+
if shaderType is GL_VERTEX_SHADER:
|
38 |
+
strShaderType = "vertex"
|
39 |
+
elif shaderType is GL_GEOMETRY_SHADER:
|
40 |
+
strShaderType = "geometry"
|
41 |
+
elif shaderType is GL_FRAGMENT_SHADER:
|
42 |
+
strShaderType = "fragment"
|
43 |
+
|
44 |
+
print("Compilation failure for " + strShaderType + " shader:\n" + str(strInfoLog))
|
45 |
+
|
46 |
+
return shader
|
47 |
+
|
48 |
+
|
49 |
+
# Function that accepts a list of shaders, compiles them, and returns a handle to the compiled program
|
50 |
+
def createProgram(shaderList):
|
51 |
+
program = glCreateProgram()
|
52 |
+
|
53 |
+
for shader in shaderList:
|
54 |
+
glAttachShader(program, shader)
|
55 |
+
|
56 |
+
glLinkProgram(program)
|
57 |
+
|
58 |
+
status = glGetProgramiv(program, GL_LINK_STATUS)
|
59 |
+
if status == GL_FALSE:
|
60 |
+
# Note that getting the error log is much simpler in Python than in C/C++
|
61 |
+
# and does not require explicit handling of the string buffer
|
62 |
+
strInfoLog = glGetProgramInfoLog(program)
|
63 |
+
print("Linker failure: \n" + str(strInfoLog))
|
64 |
+
|
65 |
+
for shader in shaderList:
|
66 |
+
glDetachShader(program, shader)
|
67 |
+
|
68 |
+
return program
|
69 |
+
|
70 |
+
|
71 |
+
# Helper function to locate and open the target file (passed in as a string).
|
72 |
+
# Returns the full path to the file as a string.
|
73 |
+
def findFileOrThrow(strBasename):
|
74 |
+
# Keep constant names in C-style convention, for readability
|
75 |
+
# when comparing to C(/C++) code.
|
76 |
+
if os.path.isfile(strBasename):
|
77 |
+
return strBasename
|
78 |
+
|
79 |
+
LOCAL_FILE_DIR = "data" + os.sep
|
80 |
+
GLOBAL_FILE_DIR = os.path.dirname(os.path.abspath(__file__)) + os.sep + "data" + os.sep
|
81 |
+
|
82 |
+
strFilename = LOCAL_FILE_DIR + strBasename
|
83 |
+
if os.path.isfile(strFilename):
|
84 |
+
return strFilename
|
85 |
+
|
86 |
+
strFilename = GLOBAL_FILE_DIR + strBasename
|
87 |
+
if os.path.isfile(strFilename):
|
88 |
+
return strFilename
|
89 |
+
|
90 |
+
raise IOError('Could not find target file ' + strBasename)
|
PIFu/lib/renderer/gl/glcontext.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Headless GPU-accelerated OpenGL context creation on Google Colaboratory.
|
2 |
+
|
3 |
+
Typical usage:
|
4 |
+
|
5 |
+
# Optional PyOpenGL configuratiopn can be done here.
|
6 |
+
# import OpenGL
|
7 |
+
# OpenGL.ERROR_CHECKING = True
|
8 |
+
|
9 |
+
# 'glcontext' must be imported before any OpenGL.* API.
|
10 |
+
from lucid.misc.gl.glcontext import create_opengl_context
|
11 |
+
|
12 |
+
# Now it's safe to import OpenGL and EGL functions
|
13 |
+
import OpenGL.GL as gl
|
14 |
+
|
15 |
+
# create_opengl_context() creates a GL context that is attached to an
|
16 |
+
# offscreen surface of the specified size. Note that rendering to buffers
|
17 |
+
# of other sizes and formats is still possible with OpenGL Framebuffers.
|
18 |
+
#
|
19 |
+
# Users are expected to directly use the EGL API in case more advanced
|
20 |
+
# context management is required.
|
21 |
+
width, height = 640, 480
|
22 |
+
create_opengl_context((width, height))
|
23 |
+
|
24 |
+
# OpenGL context is available here.
|
25 |
+
|
26 |
+
"""
|
27 |
+
|
28 |
+
from __future__ import print_function
|
29 |
+
|
30 |
+
# pylint: disable=unused-import,g-import-not-at-top,g-statement-before-imports
|
31 |
+
|
32 |
+
try:
|
33 |
+
import OpenGL
|
34 |
+
except:
|
35 |
+
print('This module depends on PyOpenGL.')
|
36 |
+
print('Please run "\033[1m!pip install -q pyopengl\033[0m" '
|
37 |
+
'prior importing this module.')
|
38 |
+
raise
|
39 |
+
|
40 |
+
import ctypes
|
41 |
+
from ctypes import pointer, util
|
42 |
+
import os
|
43 |
+
|
44 |
+
os.environ['PYOPENGL_PLATFORM'] = 'egl'
|
45 |
+
|
46 |
+
# OpenGL loading workaround.
|
47 |
+
#
|
48 |
+
# * PyOpenGL tries to load libGL, but we need libOpenGL, see [1,2].
|
49 |
+
# This could have been solved by a symlink libGL->libOpenGL, but:
|
50 |
+
#
|
51 |
+
# * Python 2.7 can't find libGL and linEGL due to a bug (see [3])
|
52 |
+
# in ctypes.util, that was only wixed in Python 3.6.
|
53 |
+
#
|
54 |
+
# So, the only solution I've found is to monkeypatch ctypes.util
|
55 |
+
# [1] https://devblogs.nvidia.com/egl-eye-opengl-visualization-without-x-server/
|
56 |
+
# [2] https://devblogs.nvidia.com/linking-opengl-server-side-rendering/
|
57 |
+
# [3] https://bugs.python.org/issue9998
|
58 |
+
_find_library_old = ctypes.util.find_library
|
59 |
+
try:
|
60 |
+
|
61 |
+
def _find_library_new(name):
|
62 |
+
return {
|
63 |
+
'GL': 'libOpenGL.so',
|
64 |
+
'EGL': 'libEGL.so',
|
65 |
+
}.get(name, _find_library_old(name))
|
66 |
+
util.find_library = _find_library_new
|
67 |
+
import OpenGL.GL as gl
|
68 |
+
import OpenGL.EGL as egl
|
69 |
+
from OpenGL import error
|
70 |
+
from OpenGL.EGL.EXT.device_base import egl_get_devices
|
71 |
+
from OpenGL.raw.EGL.EXT.platform_device import EGL_PLATFORM_DEVICE_EXT
|
72 |
+
except:
|
73 |
+
print('Unable to load OpenGL libraries. '
|
74 |
+
'Make sure you use GPU-enabled backend.')
|
75 |
+
print('Press "Runtime->Change runtime type" and set '
|
76 |
+
'"Hardware accelerator" to GPU.')
|
77 |
+
raise
|
78 |
+
finally:
|
79 |
+
util.find_library = _find_library_old
|
80 |
+
|
81 |
+
def create_initialized_headless_egl_display():
|
82 |
+
"""Creates an initialized EGL display directly on a device."""
|
83 |
+
for device in egl_get_devices():
|
84 |
+
display = egl.eglGetPlatformDisplayEXT(EGL_PLATFORM_DEVICE_EXT, device, None)
|
85 |
+
|
86 |
+
if display != egl.EGL_NO_DISPLAY and egl.eglGetError() == egl.EGL_SUCCESS:
|
87 |
+
# `eglInitialize` may or may not raise an exception on failure depending
|
88 |
+
# on how PyOpenGL is configured. We therefore catch a `GLError` and also
|
89 |
+
# manually check the output of `eglGetError()` here.
|
90 |
+
try:
|
91 |
+
initialized = egl.eglInitialize(display, None, None)
|
92 |
+
except error.GLError:
|
93 |
+
pass
|
94 |
+
else:
|
95 |
+
if initialized == egl.EGL_TRUE and egl.eglGetError() == egl.EGL_SUCCESS:
|
96 |
+
return display
|
97 |
+
return egl.EGL_NO_DISPLAY
|
98 |
+
|
99 |
+
def create_opengl_context(surface_size=(640, 480)):
|
100 |
+
"""Create offscreen OpenGL context and make it current.
|
101 |
+
|
102 |
+
Users are expected to directly use EGL API in case more advanced
|
103 |
+
context management is required.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
surface_size: (width, height), size of the offscreen rendering surface.
|
107 |
+
"""
|
108 |
+
egl_display = create_initialized_headless_egl_display()
|
109 |
+
if egl_display == egl.EGL_NO_DISPLAY:
|
110 |
+
raise ImportError('Cannot initialize a headless EGL display.')
|
111 |
+
|
112 |
+
major, minor = egl.EGLint(), egl.EGLint()
|
113 |
+
egl.eglInitialize(egl_display, pointer(major), pointer(minor))
|
114 |
+
|
115 |
+
config_attribs = [
|
116 |
+
egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT, egl.EGL_BLUE_SIZE, 8,
|
117 |
+
egl.EGL_GREEN_SIZE, 8, egl.EGL_RED_SIZE, 8, egl.EGL_DEPTH_SIZE, 24,
|
118 |
+
egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT, egl.EGL_NONE
|
119 |
+
]
|
120 |
+
config_attribs = (egl.EGLint * len(config_attribs))(*config_attribs)
|
121 |
+
|
122 |
+
num_configs = egl.EGLint()
|
123 |
+
egl_cfg = egl.EGLConfig()
|
124 |
+
egl.eglChooseConfig(egl_display, config_attribs, pointer(egl_cfg), 1,
|
125 |
+
pointer(num_configs))
|
126 |
+
|
127 |
+
width, height = surface_size
|
128 |
+
pbuffer_attribs = [
|
129 |
+
egl.EGL_WIDTH,
|
130 |
+
width,
|
131 |
+
egl.EGL_HEIGHT,
|
132 |
+
height,
|
133 |
+
egl.EGL_NONE,
|
134 |
+
]
|
135 |
+
pbuffer_attribs = (egl.EGLint * len(pbuffer_attribs))(*pbuffer_attribs)
|
136 |
+
egl_surf = egl.eglCreatePbufferSurface(egl_display, egl_cfg, pbuffer_attribs)
|
137 |
+
|
138 |
+
egl.eglBindAPI(egl.EGL_OPENGL_API)
|
139 |
+
|
140 |
+
egl_context = egl.eglCreateContext(egl_display, egl_cfg, egl.EGL_NO_CONTEXT,
|
141 |
+
None)
|
142 |
+
egl.eglMakeCurrent(egl_display, egl_surf, egl_surf, egl_context)
|
PIFu/lib/renderer/gl/init_gl.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_glut_window = None
|
2 |
+
_context_inited = None
|
3 |
+
|
4 |
+
def initialize_GL_context(width=512, height=512, egl=False):
|
5 |
+
'''
|
6 |
+
default context uses GLUT
|
7 |
+
'''
|
8 |
+
if not egl:
|
9 |
+
import OpenGL.GLUT as GLUT
|
10 |
+
display_mode = GLUT.GLUT_DOUBLE | GLUT.GLUT_RGB | GLUT.GLUT_DEPTH
|
11 |
+
global _glut_window
|
12 |
+
if _glut_window is None:
|
13 |
+
GLUT.glutInit()
|
14 |
+
GLUT.glutInitDisplayMode(display_mode)
|
15 |
+
GLUT.glutInitWindowSize(width, height)
|
16 |
+
GLUT.glutInitWindowPosition(0, 0)
|
17 |
+
_glut_window = GLUT.glutCreateWindow("My Render.")
|
18 |
+
else:
|
19 |
+
from .glcontext import create_opengl_context
|
20 |
+
global _context_inited
|
21 |
+
if _context_inited is None:
|
22 |
+
create_opengl_context((width, height))
|
23 |
+
_context_inited = True
|
24 |
+
|
PIFu/lib/renderer/gl/prt_render.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
|
4 |
+
from .framework import *
|
5 |
+
from .cam_render import CamRender
|
6 |
+
|
7 |
+
class PRTRender(CamRender):
|
8 |
+
def __init__(self, width=1600, height=1200, name='PRT Renderer', uv_mode=False, ms_rate=1, egl=False):
|
9 |
+
program_files = ['prt.vs', 'prt.fs'] if not uv_mode else ['prt_uv.vs', 'prt_uv.fs']
|
10 |
+
CamRender.__init__(self, width, height, name, program_files=program_files, color_size=8, ms_rate=ms_rate, egl=egl)
|
11 |
+
|
12 |
+
# WARNING: this differs from vertex_buffer and vertex_data in Render
|
13 |
+
self.vert_buffer = {}
|
14 |
+
self.vert_data = {}
|
15 |
+
|
16 |
+
self.norm_buffer = {}
|
17 |
+
self.norm_data = {}
|
18 |
+
|
19 |
+
self.tan_buffer = {}
|
20 |
+
self.tan_data = {}
|
21 |
+
|
22 |
+
self.btan_buffer = {}
|
23 |
+
self.btan_data = {}
|
24 |
+
|
25 |
+
self.prt1_buffer = {}
|
26 |
+
self.prt1_data = {}
|
27 |
+
self.prt2_buffer = {}
|
28 |
+
self.prt2_data = {}
|
29 |
+
self.prt3_buffer = {}
|
30 |
+
self.prt3_data = {}
|
31 |
+
|
32 |
+
self.uv_buffer = {}
|
33 |
+
self.uv_data = {}
|
34 |
+
|
35 |
+
self.render_texture_mat = {}
|
36 |
+
|
37 |
+
self.vertex_dim = {}
|
38 |
+
self.n_vertices = {}
|
39 |
+
|
40 |
+
self.norm_mat_unif = glGetUniformLocation(self.program, 'NormMat')
|
41 |
+
self.normalize_matrix = np.eye(4)
|
42 |
+
|
43 |
+
self.shcoeff_unif = glGetUniformLocation(self.program, 'SHCoeffs')
|
44 |
+
self.shcoeffs = np.zeros((9,3))
|
45 |
+
self.shcoeffs[0,:] = 1.0
|
46 |
+
#self.shcoeffs[1:,:] = np.random.rand(8,3)
|
47 |
+
|
48 |
+
self.hasAlbedoUnif = glGetUniformLocation(self.program, 'hasAlbedoMap')
|
49 |
+
self.hasNormalUnif = glGetUniformLocation(self.program, 'hasNormalMap')
|
50 |
+
|
51 |
+
self.analyticUnif = glGetUniformLocation(self.program, 'analytic')
|
52 |
+
self.analytic = False
|
53 |
+
|
54 |
+
self.rot_mat_unif = glGetUniformLocation(self.program, 'RotMat')
|
55 |
+
self.rot_matrix = np.eye(3)
|
56 |
+
|
57 |
+
def set_texture(self, mat_name, smplr_name, texture):
|
58 |
+
# texture_image: H x W x 3
|
59 |
+
width = texture.shape[1]
|
60 |
+
height = texture.shape[0]
|
61 |
+
texture = np.flip(texture, 0)
|
62 |
+
img_data = np.fromstring(texture.tostring(), np.uint8)
|
63 |
+
|
64 |
+
if mat_name not in self.render_texture_mat:
|
65 |
+
self.render_texture_mat[mat_name] = {}
|
66 |
+
if smplr_name in self.render_texture_mat[mat_name].keys():
|
67 |
+
glDeleteTextures([self.render_texture_mat[mat_name][smplr_name]])
|
68 |
+
del self.render_texture_mat[mat_name][smplr_name]
|
69 |
+
self.render_texture_mat[mat_name][smplr_name] = glGenTextures(1)
|
70 |
+
glActiveTexture(GL_TEXTURE0)
|
71 |
+
|
72 |
+
glPixelStorei(GL_UNPACK_ALIGNMENT, 1)
|
73 |
+
glBindTexture(GL_TEXTURE_2D, self.render_texture_mat[mat_name][smplr_name])
|
74 |
+
|
75 |
+
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, width, height, 0, GL_RGB, GL_UNSIGNED_BYTE, img_data)
|
76 |
+
|
77 |
+
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAX_LEVEL, 3)
|
78 |
+
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)
|
79 |
+
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)
|
80 |
+
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR)
|
81 |
+
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR_MIPMAP_LINEAR)
|
82 |
+
|
83 |
+
glGenerateMipmap(GL_TEXTURE_2D)
|
84 |
+
|
85 |
+
def set_albedo(self, texture_image, mat_name='all'):
|
86 |
+
self.set_texture(mat_name, 'AlbedoMap', texture_image)
|
87 |
+
|
88 |
+
def set_normal_map(self, texture_image, mat_name='all'):
|
89 |
+
self.set_texture(mat_name, 'NormalMap', texture_image)
|
90 |
+
|
91 |
+
def set_mesh(self, vertices, faces, norms, faces_nml, uvs, faces_uvs, prt, faces_prt, tans, bitans, mat_name='all'):
|
92 |
+
self.vert_data[mat_name] = vertices[faces.reshape([-1])]
|
93 |
+
self.n_vertices[mat_name] = self.vert_data[mat_name].shape[0]
|
94 |
+
self.vertex_dim[mat_name] = self.vert_data[mat_name].shape[1]
|
95 |
+
|
96 |
+
if mat_name not in self.vert_buffer.keys():
|
97 |
+
self.vert_buffer[mat_name] = glGenBuffers(1)
|
98 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.vert_buffer[mat_name])
|
99 |
+
glBufferData(GL_ARRAY_BUFFER, self.vert_data[mat_name], GL_STATIC_DRAW)
|
100 |
+
|
101 |
+
self.uv_data[mat_name] = uvs[faces_uvs.reshape([-1])]
|
102 |
+
if mat_name not in self.uv_buffer.keys():
|
103 |
+
self.uv_buffer[mat_name] = glGenBuffers(1)
|
104 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.uv_buffer[mat_name])
|
105 |
+
glBufferData(GL_ARRAY_BUFFER, self.uv_data[mat_name], GL_STATIC_DRAW)
|
106 |
+
|
107 |
+
self.norm_data[mat_name] = norms[faces_nml.reshape([-1])]
|
108 |
+
if mat_name not in self.norm_buffer.keys():
|
109 |
+
self.norm_buffer[mat_name] = glGenBuffers(1)
|
110 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer[mat_name])
|
111 |
+
glBufferData(GL_ARRAY_BUFFER, self.norm_data[mat_name], GL_STATIC_DRAW)
|
112 |
+
|
113 |
+
self.tan_data[mat_name] = tans[faces_nml.reshape([-1])]
|
114 |
+
if mat_name not in self.tan_buffer.keys():
|
115 |
+
self.tan_buffer[mat_name] = glGenBuffers(1)
|
116 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.tan_buffer[mat_name])
|
117 |
+
glBufferData(GL_ARRAY_BUFFER, self.tan_data[mat_name], GL_STATIC_DRAW)
|
118 |
+
|
119 |
+
self.btan_data[mat_name] = bitans[faces_nml.reshape([-1])]
|
120 |
+
if mat_name not in self.btan_buffer.keys():
|
121 |
+
self.btan_buffer[mat_name] = glGenBuffers(1)
|
122 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.btan_buffer[mat_name])
|
123 |
+
glBufferData(GL_ARRAY_BUFFER, self.btan_data[mat_name], GL_STATIC_DRAW)
|
124 |
+
|
125 |
+
self.prt1_data[mat_name] = prt[faces_prt.reshape([-1])][:,:3]
|
126 |
+
self.prt2_data[mat_name] = prt[faces_prt.reshape([-1])][:,3:6]
|
127 |
+
self.prt3_data[mat_name] = prt[faces_prt.reshape([-1])][:,6:]
|
128 |
+
|
129 |
+
if mat_name not in self.prt1_buffer.keys():
|
130 |
+
self.prt1_buffer[mat_name] = glGenBuffers(1)
|
131 |
+
if mat_name not in self.prt2_buffer.keys():
|
132 |
+
self.prt2_buffer[mat_name] = glGenBuffers(1)
|
133 |
+
if mat_name not in self.prt3_buffer.keys():
|
134 |
+
self.prt3_buffer[mat_name] = glGenBuffers(1)
|
135 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.prt1_buffer[mat_name])
|
136 |
+
glBufferData(GL_ARRAY_BUFFER, self.prt1_data[mat_name], GL_STATIC_DRAW)
|
137 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.prt2_buffer[mat_name])
|
138 |
+
glBufferData(GL_ARRAY_BUFFER, self.prt2_data[mat_name], GL_STATIC_DRAW)
|
139 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.prt3_buffer[mat_name])
|
140 |
+
glBufferData(GL_ARRAY_BUFFER, self.prt3_data[mat_name], GL_STATIC_DRAW)
|
141 |
+
|
142 |
+
glBindBuffer(GL_ARRAY_BUFFER, 0)
|
143 |
+
|
144 |
+
def set_mesh_mtl(self, vertices, faces, norms, faces_nml, uvs, faces_uvs, tans, bitans, prt):
|
145 |
+
for key in faces:
|
146 |
+
self.vert_data[key] = vertices[faces[key].reshape([-1])]
|
147 |
+
self.n_vertices[key] = self.vert_data[key].shape[0]
|
148 |
+
self.vertex_dim[key] = self.vert_data[key].shape[1]
|
149 |
+
|
150 |
+
if key not in self.vert_buffer.keys():
|
151 |
+
self.vert_buffer[key] = glGenBuffers(1)
|
152 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.vert_buffer[key])
|
153 |
+
glBufferData(GL_ARRAY_BUFFER, self.vert_data[key], GL_STATIC_DRAW)
|
154 |
+
|
155 |
+
self.uv_data[key] = uvs[faces_uvs[key].reshape([-1])]
|
156 |
+
if key not in self.uv_buffer.keys():
|
157 |
+
self.uv_buffer[key] = glGenBuffers(1)
|
158 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.uv_buffer[key])
|
159 |
+
glBufferData(GL_ARRAY_BUFFER, self.uv_data[key], GL_STATIC_DRAW)
|
160 |
+
|
161 |
+
self.norm_data[key] = norms[faces_nml[key].reshape([-1])]
|
162 |
+
if key not in self.norm_buffer.keys():
|
163 |
+
self.norm_buffer[key] = glGenBuffers(1)
|
164 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer[key])
|
165 |
+
glBufferData(GL_ARRAY_BUFFER, self.norm_data[key], GL_STATIC_DRAW)
|
166 |
+
|
167 |
+
self.tan_data[key] = tans[faces_nml[key].reshape([-1])]
|
168 |
+
if key not in self.tan_buffer.keys():
|
169 |
+
self.tan_buffer[key] = glGenBuffers(1)
|
170 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.tan_buffer[key])
|
171 |
+
glBufferData(GL_ARRAY_BUFFER, self.tan_data[key], GL_STATIC_DRAW)
|
172 |
+
|
173 |
+
self.btan_data[key] = bitans[faces_nml[key].reshape([-1])]
|
174 |
+
if key not in self.btan_buffer.keys():
|
175 |
+
self.btan_buffer[key] = glGenBuffers(1)
|
176 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.btan_buffer[key])
|
177 |
+
glBufferData(GL_ARRAY_BUFFER, self.btan_data[key], GL_STATIC_DRAW)
|
178 |
+
|
179 |
+
self.prt1_data[key] = prt[faces[key].reshape([-1])][:,:3]
|
180 |
+
self.prt2_data[key] = prt[faces[key].reshape([-1])][:,3:6]
|
181 |
+
self.prt3_data[key] = prt[faces[key].reshape([-1])][:,6:]
|
182 |
+
|
183 |
+
if key not in self.prt1_buffer.keys():
|
184 |
+
self.prt1_buffer[key] = glGenBuffers(1)
|
185 |
+
if key not in self.prt2_buffer.keys():
|
186 |
+
self.prt2_buffer[key] = glGenBuffers(1)
|
187 |
+
if key not in self.prt3_buffer.keys():
|
188 |
+
self.prt3_buffer[key] = glGenBuffers(1)
|
189 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.prt1_buffer[key])
|
190 |
+
glBufferData(GL_ARRAY_BUFFER, self.prt1_data[key], GL_STATIC_DRAW)
|
191 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.prt2_buffer[key])
|
192 |
+
glBufferData(GL_ARRAY_BUFFER, self.prt2_data[key], GL_STATIC_DRAW)
|
193 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.prt3_buffer[key])
|
194 |
+
glBufferData(GL_ARRAY_BUFFER, self.prt3_data[key], GL_STATIC_DRAW)
|
195 |
+
|
196 |
+
glBindBuffer(GL_ARRAY_BUFFER, 0)
|
197 |
+
|
198 |
+
def cleanup(self):
|
199 |
+
|
200 |
+
glBindBuffer(GL_ARRAY_BUFFER, 0)
|
201 |
+
for key in self.vert_data:
|
202 |
+
glDeleteBuffers(1, [self.vert_buffer[key]])
|
203 |
+
glDeleteBuffers(1, [self.norm_buffer[key]])
|
204 |
+
glDeleteBuffers(1, [self.uv_buffer[key]])
|
205 |
+
|
206 |
+
glDeleteBuffers(1, [self.tan_buffer[key]])
|
207 |
+
glDeleteBuffers(1, [self.btan_buffer[key]])
|
208 |
+
glDeleteBuffers(1, [self.prt1_buffer[key]])
|
209 |
+
glDeleteBuffers(1, [self.prt2_buffer[key]])
|
210 |
+
glDeleteBuffers(1, [self.prt3_buffer[key]])
|
211 |
+
|
212 |
+
glDeleteBuffers(1, [])
|
213 |
+
|
214 |
+
for smplr in self.render_texture_mat[key]:
|
215 |
+
glDeleteTextures([self.render_texture_mat[key][smplr]])
|
216 |
+
|
217 |
+
self.vert_buffer = {}
|
218 |
+
self.vert_data = {}
|
219 |
+
|
220 |
+
self.norm_buffer = {}
|
221 |
+
self.norm_data = {}
|
222 |
+
|
223 |
+
self.tan_buffer = {}
|
224 |
+
self.tan_data = {}
|
225 |
+
|
226 |
+
self.btan_buffer = {}
|
227 |
+
self.btan_data = {}
|
228 |
+
|
229 |
+
self.prt1_buffer = {}
|
230 |
+
self.prt1_data = {}
|
231 |
+
|
232 |
+
self.prt2_buffer = {}
|
233 |
+
self.prt2_data = {}
|
234 |
+
|
235 |
+
self.prt3_buffer = {}
|
236 |
+
self.prt3_data = {}
|
237 |
+
|
238 |
+
self.uv_buffer = {}
|
239 |
+
self.uv_data = {}
|
240 |
+
|
241 |
+
self.render_texture_mat = {}
|
242 |
+
|
243 |
+
self.vertex_dim = {}
|
244 |
+
self.n_vertices = {}
|
245 |
+
|
246 |
+
def randomize_sh(self):
|
247 |
+
self.shcoeffs[0,:] = 0.8
|
248 |
+
self.shcoeffs[1:,:] = 1.0*np.random.rand(8,3)
|
249 |
+
|
250 |
+
def set_sh(self, sh):
|
251 |
+
self.shcoeffs = sh
|
252 |
+
|
253 |
+
def set_norm_mat(self, scale, center):
|
254 |
+
N = np.eye(4)
|
255 |
+
N[:3, :3] = scale*np.eye(3)
|
256 |
+
N[:3, 3] = -scale*center
|
257 |
+
|
258 |
+
self.normalize_matrix = N
|
259 |
+
|
260 |
+
def draw(self):
|
261 |
+
self.draw_init()
|
262 |
+
|
263 |
+
glDisable(GL_BLEND)
|
264 |
+
#glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
|
265 |
+
glEnable(GL_MULTISAMPLE)
|
266 |
+
|
267 |
+
glUseProgram(self.program)
|
268 |
+
glUniformMatrix4fv(self.norm_mat_unif, 1, GL_FALSE, self.normalize_matrix.transpose())
|
269 |
+
glUniformMatrix4fv(self.model_mat_unif, 1, GL_FALSE, self.model_view_matrix.transpose())
|
270 |
+
glUniformMatrix4fv(self.persp_mat_unif, 1, GL_FALSE, self.projection_matrix.transpose())
|
271 |
+
|
272 |
+
if 'AlbedoMap' in self.render_texture_mat['all']:
|
273 |
+
glUniform1ui(self.hasAlbedoUnif, GLuint(1))
|
274 |
+
else:
|
275 |
+
glUniform1ui(self.hasAlbedoUnif, GLuint(0))
|
276 |
+
|
277 |
+
if 'NormalMap' in self.render_texture_mat['all']:
|
278 |
+
glUniform1ui(self.hasNormalUnif, GLuint(1))
|
279 |
+
else:
|
280 |
+
glUniform1ui(self.hasNormalUnif, GLuint(0))
|
281 |
+
|
282 |
+
glUniform1ui(self.analyticUnif, GLuint(1) if self.analytic else GLuint(0))
|
283 |
+
|
284 |
+
glUniform3fv(self.shcoeff_unif, 9, self.shcoeffs)
|
285 |
+
|
286 |
+
glUniformMatrix3fv(self.rot_mat_unif, 1, GL_FALSE, self.rot_matrix.transpose())
|
287 |
+
|
288 |
+
for mat in self.vert_buffer:
|
289 |
+
# Handle vertex buffer
|
290 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.vert_buffer[mat])
|
291 |
+
glEnableVertexAttribArray(0)
|
292 |
+
glVertexAttribPointer(0, self.vertex_dim[mat], GL_DOUBLE, GL_FALSE, 0, None)
|
293 |
+
|
294 |
+
# Handle normal buffer
|
295 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer[mat])
|
296 |
+
glEnableVertexAttribArray(1)
|
297 |
+
glVertexAttribPointer(1, 3, GL_DOUBLE, GL_FALSE, 0, None)
|
298 |
+
|
299 |
+
# Handle uv buffer
|
300 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.uv_buffer[mat])
|
301 |
+
glEnableVertexAttribArray(2)
|
302 |
+
glVertexAttribPointer(2, 2, GL_DOUBLE, GL_FALSE, 0, None)
|
303 |
+
|
304 |
+
# Handle tan buffer
|
305 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.tan_buffer[mat])
|
306 |
+
glEnableVertexAttribArray(3)
|
307 |
+
glVertexAttribPointer(3, 3, GL_DOUBLE, GL_FALSE, 0, None)
|
308 |
+
|
309 |
+
# Handle btan buffer
|
310 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.btan_buffer[mat])
|
311 |
+
glEnableVertexAttribArray(4)
|
312 |
+
glVertexAttribPointer(4, 3, GL_DOUBLE, GL_FALSE, 0, None)
|
313 |
+
|
314 |
+
# Handle PTR buffer
|
315 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.prt1_buffer[mat])
|
316 |
+
glEnableVertexAttribArray(5)
|
317 |
+
glVertexAttribPointer(5, 3, GL_DOUBLE, GL_FALSE, 0, None)
|
318 |
+
|
319 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.prt2_buffer[mat])
|
320 |
+
glEnableVertexAttribArray(6)
|
321 |
+
glVertexAttribPointer(6, 3, GL_DOUBLE, GL_FALSE, 0, None)
|
322 |
+
|
323 |
+
glBindBuffer(GL_ARRAY_BUFFER, self.prt3_buffer[mat])
|
324 |
+
glEnableVertexAttribArray(7)
|
325 |
+
glVertexAttribPointer(7, 3, GL_DOUBLE, GL_FALSE, 0, None)
|
326 |
+
|
327 |
+
for i, smplr in enumerate(self.render_texture_mat[mat]):
|
328 |
+
glActiveTexture(GL_TEXTURE0 + i)
|
329 |
+
glBindTexture(GL_TEXTURE_2D, self.render_texture_mat[mat][smplr])
|
330 |
+
glUniform1i(glGetUniformLocation(self.program, smplr), i)
|
331 |
+
|
332 |
+
glDrawArrays(GL_TRIANGLES, 0, self.n_vertices[mat])
|
333 |
+
|
334 |
+
glDisableVertexAttribArray(7)
|
335 |
+
glDisableVertexAttribArray(6)
|
336 |
+
glDisableVertexAttribArray(5)
|
337 |
+
glDisableVertexAttribArray(4)
|
338 |
+
glDisableVertexAttribArray(3)
|
339 |
+
glDisableVertexAttribArray(2)
|
340 |
+
glDisableVertexAttribArray(1)
|
341 |
+
glDisableVertexAttribArray(0)
|
342 |
+
|
343 |
+
glBindBuffer(GL_ARRAY_BUFFER, 0)
|
344 |
+
|
345 |
+
glUseProgram(0)
|
346 |
+
|
347 |
+
glDisable(GL_BLEND)
|
348 |
+
glDisable(GL_MULTISAMPLE)
|
349 |
+
|
350 |
+
self.draw_end()
|