Committed by qninhdt · Commit 1ab03a3 (verified) · Parent(s): 8d6fb7c

Upload 68 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ datasets/celeba_anno/list_attr_celeba.txt filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 VISTEC - Vidyasirimedhi Institute of Science and Technology
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,199 @@
1
+ # Official implementation of Diffusion Autoencoders
2
+
3
+ A CVPR 2022 (ORAL) paper ([paper](https://openaccess.thecvf.com/content/CVPR2022/html/Preechakul_Diffusion_Autoencoders_Toward_a_Meaningful_and_Decodable_Representation_CVPR_2022_paper.html), [site](https://diff-ae.github.io/), [5-min video](https://youtu.be/i3rjEsiHoUU)):
4
+
5
+ ```
6
+ @inproceedings{preechakul2021diffusion,
7
+ title={Diffusion Autoencoders: Toward a Meaningful and Decodable Representation},
8
+ author={Preechakul, Konpat and Chatthee, Nattanat and Wizadwongsa, Suttisak and Suwajanakorn, Supasorn},
9
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
10
+ year={2022},
11
+ }
12
+ ```
13
+
14
+ ## Usage
15
+
16
+ ⚙️ Try a Colab walkthrough: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1OTfwkklN-IEd4hFk4LnweOleyDtS4XTh/view?usp=sharing)
17
+
18
+ 🤗 Try a web demo: [![Replicate](https://replicate.com/cjwbw/diffae/badge)](https://replicate.com/cjwbw/diffae)
19
+
20
+ Note: Since we expect a lot of changes to the codebase, please fork the repo before using it.
21
+
22
+ ### Prerequisites
23
+
24
+ See `requirements.txt`
25
+
26
+ ```
27
+ pip install -r requirements.txt
28
+ ```
29
+
30
+ ### Quick start
31
+
32
+ Quick-start tasks are provided as Jupyter notebooks:
33
+
34
+ For unconditional generation: `sample.ipynb`
35
+
36
+ For manipulation: `manipulate.ipynb`
37
+
38
+ For interpolation: `interpolate.ipynb`
39
+
40
+ For autoencoding: `autoencoding.ipynb`
41
+
42
+ Aligning your own images:
43
+
44
+ 1. Put images into the `imgs` directory
45
+ 2. Run `align.py` (need to `pip install dlib requests`)
46
+ 3. Result images will be available in `imgs_align` directory
47
+
48
+ <table>
49
+ <tr>
50
+ <th width="33%">
51
+ Original in <code>imgs</code> directory<br><img src="imgs/sandy.JPG" style="width: 100%">
52
+ </th>
53
+ <th width="33%">
54
+ Aligned with <code>align.py</code><br><img src="imgs_align/sandy.png" style="width: 100%">
55
+ </th>
56
+ <th width="33%">
57
+ Using <code>manipulate.ipynb</code><br><img src="imgs_manipulated/sandy-wavyhair.png" style="width: 100%">
58
+ </th>
59
+ </tr>
60
+ </table>
61
+
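The same alignment can also be driven from Python using the helpers defined in `align.py` from this commit. A minimal sketch, assuming the dlib landmarks model has already been downloaded and unpacked to the path used in `align.py`, with purely illustrative file names:

```
import os

from align import LandmarksDetector, image_align

# Assumes the landmarks model was fetched as in align.py's __main__ section.
detector = LandmarksDetector('temp/shape_predictor_68_face_landmarks.dat')

src, dst = 'imgs/sandy.JPG', 'imgs_align/sandy.png'  # illustrative paths
os.makedirs('imgs_align', exist_ok=True)
for face_landmarks in detector.get_landmarks(src):
    image_align(src, dst, face_landmarks, output_size=256)
```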
62
+
63
+ ### Checkpoints
64
+
65
+ We provide checkpoints for the following models:
66
+
67
+ 1. DDIM: **FFHQ128** ([72M](https://drive.google.com/drive/folders/1-fa46UPSgy9ximKngBflgSj3u87-DLrw), [130M](https://drive.google.com/drive/folders/1-Sqes07fs1y9sAYXuYWSoDE_xxTtH4yx)), [**Bedroom128**](https://drive.google.com/drive/folders/1-_8LZd5inoAOBT-hO5f7RYivt95FbYT1), [**Horse128**](https://drive.google.com/drive/folders/10Hq3zIlJs9ZSiXDQVYuVJVf0cX4a_nDB)
68
+ 2. DiffAE (autoencoding only): [**FFHQ256**](https://drive.google.com/drive/folders/1-5zfxT6Gl-GjxM7z9ZO2AHlB70tfmF6V), **FFHQ128** ([72M](https://drive.google.com/drive/folders/10bmB6WhLkgxybkhso5g3JmIFPAnmZMQO), [130M](https://drive.google.com/drive/folders/10UNtFNfxbHBPkoIh003JkSPto5s-VbeN)), [**Bedroom128**](https://drive.google.com/drive/folders/12EdjbIKnvP5RngKsR0UU-4kgpPAaYtlp), [**Horse128**](https://drive.google.com/drive/folders/12EtTRXzQc5uPHscpjIcci-Rg-OGa_N30)
69
+ 3. DiffAE (with latent DPM, can sample): [**FFHQ256**](https://drive.google.com/drive/folders/1-H8WzKc65dEONN-DQ87TnXc23nTXDTYb), [**FFHQ128**](https://drive.google.com/drive/folders/11pdjMQ6NS8GFFiGOq3fziNJxzXU1Mw3l), [**Bedroom128**](https://drive.google.com/drive/folders/11mdxv2lVX5Em8TuhNJt-Wt2XKt25y8zU), [**Horse128**](https://drive.google.com/drive/folders/11k8XNDK3ENxiRnPSUdJ4rnagJYo4uKEo)
70
+ 4. DiffAE's classifiers (for manipulation): [**FFHQ256's latent on CelebAHQ**](https://drive.google.com/drive/folders/117Wv7RZs_gumgrCOIhDEWgsNy6BRJorg), [**FFHQ128's latent on CelebAHQ**](https://drive.google.com/drive/folders/11EYIyuK6IX44C8MqreUyMgPCNiEnwhmI)
71
+
72
+ Checkpoints should be kept in a separate `checkpoints` directory.
73
+ Download the checkpoints and place them in the `checkpoints` directory so that it looks like this:
74
+
75
+ ```
76
+ checkpoints/
77
+ - bedroom128_autoenc
78
+ - last.ckpt # diffae checkpoint
79
+ - latent.ckpt # predicted z_sem on the dataset
80
+ - bedroom128_autoenc_latent
81
+ - last.ckpt # diffae + latent DPM checkpoint
82
+ - bedroom128_ddpm
83
+ - ...
84
+ ```
85
+
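The `last.ckpt` files are PyTorch Lightning checkpoints, so a quick way to sanity-check a download is to load one with `torch.load` and look at the stored keys. A minimal sketch, assuming the layout above and that it is run from the repo root (pickled config objects inside the checkpoint may need the repo's modules to be importable):

```
import torch

# Inspect a downloaded checkpoint on CPU; the exact keys beyond 'state_dict'
# are not guaranteed and depend on how the checkpoint was saved.
ckpt = torch.load('checkpoints/bedroom128_autoenc/last.ckpt', map_location='cpu')
print(list(ckpt.keys()))
print(len(ckpt['state_dict']), 'tensors in state_dict')
```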
86
+
87
+ ### LMDB Datasets
88
+
89
+ We do not own any of the following datasets. We provide ready-to-use LMDB datasets for convenience.
90
+
91
+ - [FFHQ](https://1drv.ms/f/s!Ar2O0vx8sW70uLV1Ivk2pTjam1A8VA)
92
+ - [CelebAHQ](https://1drv.ms/f/s!Ar2O0vx8sW70uL4GMeWEciHkHdH6vQ)
93
+
94
+ **Broken links**
95
+
96
+ Note: I'm trying to recover the following links.
97
+
98
+ - [CelebA](https://drive.google.com/drive/folders/1HJAhK2hLYcT_n0gWlCu5XxdZj-bPekZ0?usp=sharing)
99
+ - [LSUN Bedroom](https://drive.google.com/drive/folders/1O_3aT3LtY1YDE2pOQCp6MFpCk7Pcpkhb?usp=sharing)
100
+ - [LSUN Horse](https://drive.google.com/drive/folders/1ooHW7VivZUs4i5CarPaWxakCwfeqAK8l?usp=sharing)
101
+
102
+ The directory tree should be:
103
+
104
+ ```
105
+ datasets/
106
+ - bedroom256.lmdb
107
+ - celebahq256.lmdb
108
+ - celeba.lmdb
109
+ - ffhq256.lmdb
110
+ - horse256.lmdb
111
+ ```
112
+
113
+ You can also download the datasets from their original sources and use our provided code to package them as LMDB files.
114
+ The original sources for each dataset are as follows:
115
+
116
+ - FFHQ (https://github.com/NVlabs/ffhq-dataset)
117
+ - CelebAHQ (https://github.com/switchablenorms/CelebAMask-HQ)
118
+ - CelebA (https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
119
+ - LSUN (https://github.com/fyu/lsun)
120
+
121
+ The conversion scripts are provided as:
122
+
123
+ ```
124
+ data_resize_bedroom.py
125
+ data_resize_celebhq.py
126
+ data_resize_celeba.py
127
+ data_resize_ffhq.py
128
+ data_resize_horse.py
129
+ ```
130
+
131
+ Google drive: https://drive.google.com/drive/folders/1abNP4QKGbNnymjn8607BF0cwxX2L23jh?usp=sharing
132
+
133
+
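To verify that a downloaded (or freshly converted) LMDB file is readable, note that the conversion scripts above store images under keys of the form `<size>-<zero-padded index>` plus a `length` entry. A minimal sketch for `bedroom256.lmdb`, assuming the 7-digit padding used by `data_resize_bedroom.py` (other scripts use different padding and image formats):

```
from io import BytesIO

import lmdb
from PIL import Image

# Open read-only and fetch the first stored image plus the dataset length.
with lmdb.open('datasets/bedroom256.lmdb', readonly=True, readahead=False, lock=False) as env:
    with env.begin(write=False) as txn:
        length = int(txn.get('length'.encode('utf-8')))
        img = Image.open(BytesIO(txn.get(f"256-{str(0).zfill(7)}".encode('utf-8'))))

print(length, img.size)
```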
134
+ ## Training
135
+
136
+ We provide scripts for training & evaluating DDIM and DiffAE (including the latent DPM) on the following datasets: FFHQ128, FFHQ256, Bedroom128, Horse128, Celeba64 (D2C's crop).
137
+ The evaluation results (FIDs) are usually written to the `eval` directory.
138
+
139
+ Note: Most experiments require at least 4x V100s to train the DPM models, while the accompanying latent DPM needs only 1x 2080Ti.
140
+
141
+
142
+
143
+ **FFHQ128**
144
+ ```
145
+ # diffae
146
+ python run_ffhq128.py
147
+ # ddim
148
+ python run_ffhq128_ddim.py
149
+ ```
150
+
151
+ A classifier (for manipulation) can be trained using:
152
+ ```
153
+ python run_ffhq128_cls.py
154
+ ```
155
+
156
+ **FFHQ256**
157
+
158
+ We only trained the DiffAE due to the high computation cost.
159
+ This requires 8x V100s.
160
+ ```
161
+ sbatch run_ffhq256.py
162
+ ```
163
+
164
+ After that job finishes, train the latent DPM (this requires only 1x 2080Ti):
165
+ ```
166
+ python run_ffhq256_latent.py
167
+ ```
168
+
169
+ A classifier (for manipulation) can be trained using:
170
+ ```
171
+ python run_ffhq256_cls.py
172
+ ```
173
+
174
+ **Bedroom128**
175
+
176
+ ```
177
+ # diffae
178
+ python run_bedroom128.py
179
+ # ddim
180
+ python run_bedroom128_ddim.py
181
+ ```
182
+
183
+ **Horse128**
184
+
185
+ ```
186
+ # diffae
187
+ python run_horse128.py
188
+ # ddim
189
+ python run_horse128_ddim.py
190
+ ```
191
+
192
+ **Celeba64**
193
+
194
+ This experiment can be run on 2080Ti's.
195
+
196
+ ```
197
+ # diffae
198
+ python run_celeba64.py
199
+ ```
README.md.backup ADDED
@@ -0,0 +1,175 @@
1
+ # Official implementation of Diffusion Autoencoders
2
+
3
+ A CVPR 2022 paper:
4
+
5
+ > Preechakul, Konpat, Nattanat Chatthee, Suttisak Wizadwongsa, and Supasorn Suwajanakorn. 2021. “Diffusion Autoencoders: Toward a Meaningful and Decodable Representation.” arXiv [cs.CV]. arXiv. http://arxiv.org/abs/2111.15640.
6
+
7
+ ## Usage
8
+
9
+ Note: Since we expect a lot of changes to the codebase, please fork the repo before using it.
10
+
11
+ ### Prerequisites
12
+
13
+ See `requirements.txt`
14
+
15
+ ```
16
+ pip install -r requirements.txt
17
+ ```
18
+
19
+ ### Quick start
20
+
21
+ Quick-start tasks are provided as Jupyter notebooks:
22
+
23
+ For unconditional generation: `sample.ipynb`
24
+
25
+ For manipulation: `manipulate.ipynb`
26
+
27
+ Aligning your own images:
28
+
29
+ 1. Put images into the `imgs` directory
30
+ 2. Run `align.py` (need to `pip install dlib requests`)
31
+ 3. Result images will be available in `imgs_align` directory
32
+
33
+
34
+ <style type="text/css">
35
+ img {
36
+ height: 256px;
37
+ }
38
+ </style>
39
+
40
+ | ![](imgs/sandy.JPG) | ![](imgs_align/sandy.png) | ![](imgs_manipulated/sandy-wavyhair.png) |
41
+ |---|---|---|
42
+
43
+
44
+ ### Checkpoints
45
+
46
+ We provide checkpoints for the following models:
47
+
48
+ 1. DDIM: **FFHQ128** ([72M](https://drive.google.com/drive/folders/1-J8FPNZOQxSqpfTpwRXawLi2KKGL1qlK?usp=sharing), [130M](https://drive.google.com/drive/folders/17T5YJXpYdgE6cWltN8gZFxRsJzpVxnLh?usp=sharing)), [**Bedroom128**](https://drive.google.com/drive/folders/19s-lAiK7fGD5Meo5obNV5o0L3MfqU0Sk?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1PiC5JWLcd8mZW9cghDCR0V4Hx0QCXOor?usp=sharing)
49
+ 2. DiffAE (autoencoding only): [**FFHQ256**](https://drive.google.com/drive/folders/1hTP9QbYXwv_Nl5sgcZNH0yKprJx7ivC5?usp=sharing), **FFHQ128** ([72M](https://drive.google.com/drive/folders/15QHmZP1G5jEMh80R1Nbtdb4ZKb6VvfII?usp=sharing), [130M](https://drive.google.com/drive/folders/1UlwLwgv16cEqxTn7g-V2ykIyopmY_fVz?usp=sharing)), [**Bedroom128**](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
50
+ 3. DiffAE (with latent DPM, can sample): [**FFHQ256**](https://drive.google.com/drive/folders/1MonJKYwVLzvCFYuVhp-l9mChq5V2XI6w?usp=sharing), [**FFHQ128**](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing), [**Bedroom128**](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
51
+ 4. DiffAE's classifiers (for manipulation): [**FFHQ256's latent on CelebAHQ**](https://drive.google.com/drive/folders/1QGkTfvNhgi_TbbV8GbX1Emrp0lStsqLj?usp=sharing), [**FFHQ128's latent on CelebAHQ**](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing)
52
+
53
+ Checkpoints should be kept in a separate `checkpoints` directory.
54
+ Download the checkpoints and place them in the `checkpoints` directory so that it looks like this:
55
+
56
+ ```
57
+ checkpoints/
58
+ - bedroom128_autoenc
59
+ - last.ckpt # diffae checkpoint
60
+ - latent.ckpt # predicted z_sem on the dataset
61
+ - bedroom128_autoenc_latent
62
+ - last.ckpt # diffae + latent DPM checkpoint
63
+ - bedroom128_ddpm
64
+ - ...
65
+ ```
66
+
67
+
68
+ ### LMDB Datasets
69
+
70
+ We do not own any of the following datasets. We provide ready-to-use LMDB datasets for convenience.
71
+
72
+ - [FFHQ](https://drive.google.com/drive/folders/1ww7itaSo53NDMa0q-wn-3HWZ3HHqK1IK?usp=sharing)
73
+ - [CelebAHQ](https://drive.google.com/drive/folders/1SX3JuVHjYA8sA28EGxr_IoHJ63s4Btbl?usp=sharing)
74
+ - [CelebA](https://drive.google.com/drive/folders/1HJAhK2hLYcT_n0gWlCu5XxdZj-bPekZ0?usp=sharing)
75
+ - [LSUN Bedroom](https://drive.google.com/drive/folders/1O_3aT3LtY1YDE2pOQCp6MFpCk7Pcpkhb?usp=sharing)
76
+ - [LSUN Horse](https://drive.google.com/drive/folders/1ooHW7VivZUs4i5CarPaWxakCwfeqAK8l?usp=sharing)
77
+
78
+ The directory tree should be:
79
+
80
+ ```
81
+ datasets/
82
+ - bedroom256.lmdb
83
+ - celebahq256.lmdb
84
+ - celeba.lmdb
85
+ - ffhq256.lmdb
86
+ - horse256.lmdb
87
+ ```
88
+
89
+ You can also download the datasets from their original sources and use our provided code to package them as LMDB files.
90
+ The original sources for each dataset are as follows:
91
+
92
+ - FFHQ (https://github.com/NVlabs/ffhq-dataset)
93
+ - CelebAHQ (https://github.com/switchablenorms/CelebAMask-HQ)
94
+ - CelebA (https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
95
+ - LSUN (https://github.com/fyu/lsun)
96
+
97
+ The conversion scripts are provided as:
98
+
99
+ ```
100
+ data_resize_bedroom.py
101
+ data_resize_celebhq.py
102
+ data_resize_celeba.py
103
+ data_resize_ffhq.py
104
+ data_resize_horse.py
105
+ ```
106
+
107
+ Google drive: https://drive.google.com/drive/folders/1abNP4QKGbNnymjn8607BF0cwxX2L23jh?usp=sharing
108
+
109
+
110
+ ## Training
111
+
112
+ We provide scripts for training & evaluating DDIM and DiffAE (including the latent DPM) on the following datasets: FFHQ128, FFHQ256, Bedroom128, Horse128, Celeba64 (D2C's crop).
113
+ The evaluation results (FIDs) are usually written to the `eval` directory.
114
+
115
+ Note: Most experiments require at least 4x V100s to train the DPM models, while the accompanying latent DPM needs only 1x 2080Ti.
116
+
117
+
118
+
119
+ **FFHQ128**
120
+ ```
121
+ # diffae
122
+ python run_ffhq128.py
123
+ # ddim
124
+ python run_ffhq128_ddim.py
125
+ ```
126
+
127
+ A classifier (for manipulation) can be trained using:
128
+ ```
129
+ python run_ffhq128_cls.py
130
+ ```
131
+
132
+ **FFHQ256**
133
+
134
+ We only trained the DiffAE due to the high computation cost.
135
+ This requires 8x V100s.
136
+ ```
137
+ sbatch run_ffhq256.py
138
+ ```
139
+
140
+ After that job finishes, train the latent DPM (this requires only 1x 2080Ti):
141
+ ```
142
+ python run_ffhq256_latent.py
143
+ ```
144
+
145
+ A classifier (for manipulation) can be trained using:
146
+ ```
147
+ python run_ffhq256_cls.py
148
+ ```
149
+
150
+ **Bedroom128**
151
+
152
+ ```
153
+ # diffae
154
+ python run_bedroom128.py
155
+ # ddim
156
+ python run_bedroom128_ddim.py
157
+ ```
158
+
159
+ **Horse128**
160
+
161
+ ```
162
+ # diffae
163
+ python run_horse128.py
164
+ # ddim
165
+ python run_horse128_ddim.py
166
+ ```
167
+
168
+ **Celeba64**
169
+
170
+ This experiment can be run on 2080Ti's.
171
+
172
+ ```
173
+ # diffae
174
+ python run_celeba64.py
175
+ ```
align.py ADDED
@@ -0,0 +1,249 @@
1
+ import bz2
2
+ import os
3
+ import os.path as osp
4
+ import sys
5
+ from multiprocessing import Pool
6
+
7
+ import dlib
8
+ import numpy as np
9
+ import PIL.Image
10
+ import requests
11
+ import scipy.ndimage
12
+ from tqdm import tqdm
13
+ from argparse import ArgumentParser
14
+
15
+ LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'
16
+
17
+
18
+ def image_align(src_file,
19
+ dst_file,
20
+ face_landmarks,
21
+ output_size=1024,
22
+ transform_size=4096,
23
+ enable_padding=True):
24
+ # Align function from FFHQ dataset pre-processing step
25
+ # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
26
+
27
+ lm = np.array(face_landmarks)
28
+ lm_chin = lm[0:17] # left-right
29
+ lm_eyebrow_left = lm[17:22] # left-right
30
+ lm_eyebrow_right = lm[22:27] # left-right
31
+ lm_nose = lm[27:31] # top-down
32
+ lm_nostrils = lm[31:36] # top-down
33
+ lm_eye_left = lm[36:42] # left-clockwise
34
+ lm_eye_right = lm[42:48] # left-clockwise
35
+ lm_mouth_outer = lm[48:60] # left-clockwise
36
+ lm_mouth_inner = lm[60:68] # left-clockwise
37
+
38
+ # Calculate auxiliary vectors.
39
+ eye_left = np.mean(lm_eye_left, axis=0)
40
+ eye_right = np.mean(lm_eye_right, axis=0)
41
+ eye_avg = (eye_left + eye_right) * 0.5
42
+ eye_to_eye = eye_right - eye_left
43
+ mouth_left = lm_mouth_outer[0]
44
+ mouth_right = lm_mouth_outer[6]
45
+ mouth_avg = (mouth_left + mouth_right) * 0.5
46
+ eye_to_mouth = mouth_avg - eye_avg
47
+
48
+ # Choose oriented crop rectangle.
49
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
50
+ x /= np.hypot(*x)
51
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
52
+ y = np.flipud(x) * [-1, 1]
53
+ c = eye_avg + eye_to_mouth * 0.1
54
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
55
+ qsize = np.hypot(*x) * 2
56
+
57
+ # Load in-the-wild image.
58
+ if not os.path.isfile(src_file):
59
+ print(
60
+ '\nCannot find source image. Please run "--wilds" before "--align".'
61
+ )
62
+ return
63
+ img = PIL.Image.open(src_file)
64
+ img = img.convert('RGB')
65
+
66
+ # Shrink.
67
+ shrink = int(np.floor(qsize / output_size * 0.5))
68
+ if shrink > 1:
69
+ rsize = (int(np.rint(float(img.size[0]) / shrink)),
70
+ int(np.rint(float(img.size[1]) / shrink)))
71
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
72
+ quad /= shrink
73
+ qsize /= shrink
74
+
75
+ # Crop.
76
+ border = max(int(np.rint(qsize * 0.1)), 3)
77
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
78
+ int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
79
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
80
+ min(crop[2] + border,
81
+ img.size[0]), min(crop[3] + border, img.size[1]))
82
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
83
+ img = img.crop(crop)
84
+ quad -= crop[0:2]
85
+
86
+ # Pad.
87
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
88
+ int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
89
+ pad = (max(-pad[0] + border,
90
+ 0), max(-pad[1] + border,
91
+ 0), max(pad[2] - img.size[0] + border,
92
+ 0), max(pad[3] - img.size[1] + border, 0))
93
+ if enable_padding and max(pad) > border - 4:
94
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
95
+ img = np.pad(np.float32(img),
96
+ ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
97
+ h, w, _ = img.shape
98
+ y, x, _ = np.ogrid[:h, :w, :1]
99
+ mask = np.maximum(
100
+ 1.0 -
101
+ np.minimum(np.float32(x) / pad[0],
102
+ np.float32(w - 1 - x) / pad[2]), 1.0 -
103
+ np.minimum(np.float32(y) / pad[1],
104
+ np.float32(h - 1 - y) / pad[3]))
105
+ blur = qsize * 0.02
106
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
107
+ img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
108
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
109
+ img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)),
110
+ 'RGB')
111
+ quad += pad[:2]
112
+
113
+ # Transform.
114
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
115
+ (quad + 0.5).flatten(), PIL.Image.BILINEAR)
116
+ if output_size < transform_size:
117
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
118
+
119
+ # Save aligned image.
120
+ img.save(dst_file, 'PNG')
121
+
122
+
123
+ class LandmarksDetector:
124
+ def __init__(self, predictor_model_path):
125
+ """
126
+ :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file
127
+ """
128
+ self.detector = dlib.get_frontal_face_detector(
129
+ ) # cnn_face_detection_model_v1 also can be used
130
+ self.shape_predictor = dlib.shape_predictor(predictor_model_path)
131
+
132
+ def get_landmarks(self, image):
133
+ img = dlib.load_rgb_image(image)
134
+ dets = self.detector(img, 1)
135
+
136
+ for detection in dets:
137
+ face_landmarks = [
138
+ (item.x, item.y)
139
+ for item in self.shape_predictor(img, detection).parts()
140
+ ]
141
+ yield face_landmarks
142
+
143
+
144
+ def unpack_bz2(src_path):
145
+ dst_path = src_path[:-4]
146
+ if os.path.exists(dst_path):
147
+ print('cached')
148
+ return dst_path
149
+ data = bz2.BZ2File(src_path).read()
150
+ with open(dst_path, 'wb') as fp:
151
+ fp.write(data)
152
+ return dst_path
153
+
154
+
155
+ def work_landmark(raw_img_path, img_name, face_landmarks):
156
+ face_img_name = '%s.png' % (os.path.splitext(img_name)[0], )
157
+ aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
158
+ if os.path.exists(aligned_face_path):
159
+ return
160
+ image_align(raw_img_path,
161
+ aligned_face_path,
162
+ face_landmarks,
163
+ output_size=256)
164
+
165
+
166
+ def get_file(src, tgt):
167
+ if os.path.exists(tgt):
168
+ print('cached')
169
+ return tgt
170
+ tgt_dir = os.path.dirname(tgt)
171
+ if not os.path.exists(tgt_dir):
172
+ os.makedirs(tgt_dir)
173
+ file = requests.get(src)
174
+ open(tgt, 'wb').write(file.content)
175
+ return tgt
176
+
177
+
178
+ if __name__ == "__main__":
179
+ """
180
+ Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step
181
+ python align.py -i imgs -o imgs_align
182
+ """
183
+ parser = ArgumentParser()
184
+ parser.add_argument("-i",
185
+ "--input_imgs_path",
186
+ type=str,
187
+ default="imgs",
188
+ help="input images directory path")
189
+ parser.add_argument("-o",
190
+ "--output_imgs_path",
191
+ type=str,
192
+ default="imgs_align",
193
+ help="output images directory path")
194
+
195
+ args = parser.parse_args()
196
+
197
+ # downloading and unpacking the landmarks model can take a very long time ...
198
+ landmarks_model_path = unpack_bz2(
199
+ get_file(
200
+ 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2',
201
+ 'temp/shape_predictor_68_face_landmarks.dat.bz2'))
202
+
203
+ # RAW_IMAGES_DIR = sys.argv[1]
204
+ # ALIGNED_IMAGES_DIR = sys.argv[2]
205
+ RAW_IMAGES_DIR = args.input_imgs_path
206
+ ALIGNED_IMAGES_DIR = args.output_imgs_path
207
+
208
+ if not osp.exists(ALIGNED_IMAGES_DIR): os.makedirs(ALIGNED_IMAGES_DIR)
209
+
210
+ files = os.listdir(RAW_IMAGES_DIR)
211
+ print(f'total img files {len(files)}')
212
+ with tqdm(total=len(files)) as progress:
213
+
214
+ def cb(*args):
215
+ # print('update')
216
+ progress.update()
217
+
218
+ def err_cb(e):
219
+ print('error:', e)
220
+
221
+ with Pool(8) as pool:
222
+ res = []
223
+ landmarks_detector = LandmarksDetector(landmarks_model_path)
224
+ for img_name in files:
225
+ raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name)
226
+ # print('img_name:', img_name)
227
+ for i, face_landmarks in enumerate(
228
+ landmarks_detector.get_landmarks(raw_img_path),
229
+ start=1):
230
+ # assert i == 1, f'{i}'
231
+ # print(i, face_landmarks)
232
+ # face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
233
+ # aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
234
+ # image_align(raw_img_path, aligned_face_path, face_landmarks, output_size=256)
235
+
236
+ work_landmark(raw_img_path, img_name, face_landmarks)
237
+ progress.update()
238
+
239
+ # job = pool.apply_async(
240
+ # work_landmark,
241
+ # (raw_img_path, img_name, face_landmarks),
242
+ # callback=cb,
243
+ # error_callback=err_cb,
244
+ # )
245
+ # res.append(job)
246
+
247
+ # pool.close()
248
+ # pool.join()
249
+ print(f"output aligned images at: {ALIGNED_IMAGES_DIR}")
autoencoding.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
choices.py ADDED
@@ -0,0 +1,179 @@
1
+ from enum import Enum
2
+ from torch import nn
3
+
4
+
5
+ class TrainMode(Enum):
6
+ # manipulate mode = training the classifier
7
+ manipulate = 'manipulate'
8
+ # default training mode!
9
+ diffusion = 'diffusion'
10
+ # default latent training mode!
11
+ # fitting a DDPM to a given latent
12
+ latent_diffusion = 'latentdiffusion'
13
+
14
+ def is_manipulate(self):
15
+ return self in [
16
+ TrainMode.manipulate,
17
+ ]
18
+
19
+ def is_diffusion(self):
20
+ return self in [
21
+ TrainMode.diffusion,
22
+ TrainMode.latent_diffusion,
23
+ ]
24
+
25
+ def is_autoenc(self):
26
+ # the network possibly does autoencoding
27
+ return self in [
28
+ TrainMode.diffusion,
29
+ ]
30
+
31
+ def is_latent_diffusion(self):
32
+ return self in [
33
+ TrainMode.latent_diffusion,
34
+ ]
35
+
36
+ def use_latent_net(self):
37
+ return self.is_latent_diffusion()
38
+
39
+ def require_dataset_infer(self):
40
+ """
41
+ Whether training in this mode requires the latent variables to be precomputed.
42
+ """
43
+ # this will precalculate all the latents before hand
44
+ # and the dataset will be all the predicted latents
45
+ return self in [
46
+ TrainMode.latent_diffusion,
47
+ TrainMode.manipulate,
48
+ ]
49
+
50
+
51
+ class ManipulateMode(Enum):
52
+ """
53
+ how to train the classifier to manipulate
54
+ """
55
+ # train on whole celeba attr dataset
56
+ celebahq_all = 'celebahq_all'
57
+ # celeba with D2C's crop
58
+ d2c_fewshot = 'd2cfewshot'
59
+ d2c_fewshot_allneg = 'd2cfewshotallneg'
60
+
61
+ def is_celeba_attr(self):
62
+ return self in [
63
+ ManipulateMode.d2c_fewshot,
64
+ ManipulateMode.d2c_fewshot_allneg,
65
+ ManipulateMode.celebahq_all,
66
+ ]
67
+
68
+ def is_single_class(self):
69
+ return self in [
70
+ ManipulateMode.d2c_fewshot,
71
+ ManipulateMode.d2c_fewshot_allneg,
72
+ ]
73
+
74
+ def is_fewshot(self):
75
+ return self in [
76
+ ManipulateMode.d2c_fewshot,
77
+ ManipulateMode.d2c_fewshot_allneg,
78
+ ]
79
+
80
+ def is_fewshot_allneg(self):
81
+ return self in [
82
+ ManipulateMode.d2c_fewshot_allneg,
83
+ ]
84
+
85
+
86
+ class ModelType(Enum):
87
+ """
88
+ Kinds of the backbone models
89
+ """
90
+
91
+ # unconditional ddpm
92
+ ddpm = 'ddpm'
93
+ # autoencoding ddpm cannot do unconditional generation
94
+ autoencoder = 'autoencoder'
95
+
96
+ def has_autoenc(self):
97
+ return self in [
98
+ ModelType.autoencoder,
99
+ ]
100
+
101
+ def can_sample(self):
102
+ return self in [ModelType.ddpm]
103
+
104
+
105
+ class ModelName(Enum):
106
+ """
107
+ List of all supported model classes
108
+ """
109
+
110
+ beatgans_ddpm = 'beatgans_ddpm'
111
+ beatgans_autoenc = 'beatgans_autoenc'
112
+
113
+
114
+ class ModelMeanType(Enum):
115
+ """
116
+ Which type of output the model predicts.
117
+ """
118
+
119
+ eps = 'eps' # the model predicts epsilon
120
+
121
+
122
+ class ModelVarType(Enum):
123
+ """
124
+ What is used as the model's output variance.
125
+
126
+ The LEARNED_RANGE option has been added to allow the model to predict
127
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
128
+ """
129
+
130
+ # posterior beta_t
131
+ fixed_small = 'fixed_small'
132
+ # beta_t
133
+ fixed_large = 'fixed_large'
134
+
135
+
136
+ class LossType(Enum):
137
+ mse = 'mse' # use raw MSE loss (and KL when learning variances)
138
+ l1 = 'l1'
139
+
140
+
141
+ class GenerativeType(Enum):
142
+ """
143
+ How's a sample generated
144
+ """
145
+
146
+ ddpm = 'ddpm'
147
+ ddim = 'ddim'
148
+
149
+
150
+ class OptimizerType(Enum):
151
+ adam = 'adam'
152
+ adamw = 'adamw'
153
+
154
+
155
+ class Activation(Enum):
156
+ none = 'none'
157
+ relu = 'relu'
158
+ lrelu = 'lrelu'
159
+ silu = 'silu'
160
+ tanh = 'tanh'
161
+
162
+ def get_act(self):
163
+ if self == Activation.none:
164
+ return nn.Identity()
165
+ elif self == Activation.relu:
166
+ return nn.ReLU()
167
+ elif self == Activation.lrelu:
168
+ return nn.LeakyReLU(negative_slope=0.2)
169
+ elif self == Activation.silu:
170
+ return nn.SiLU()
171
+ elif self == Activation.tanh:
172
+ return nn.Tanh()
173
+ else:
174
+ raise NotImplementedError()
175
+
176
+
177
+ class ManipulateLossType(Enum):
178
+ bce = 'bce'
179
+ mse = 'mse'
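The classes above are plain string-valued enums with convenience predicates, so the rest of the codebase can branch on training behavior without string comparisons. A minimal usage sketch (not part of the repo):

```
from choices import ModelType, TrainMode

mode = TrainMode.latent_diffusion
assert mode.is_diffusion() and mode.use_latent_net()
assert mode.require_dataset_infer()      # latents are precomputed for this mode

assert ModelType.ddpm.can_sample()
assert not ModelType.autoencoder.can_sample()
```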
cog.yaml ADDED
@@ -0,0 +1,29 @@
1
+ build:
2
+ cuda: "10.2"
3
+ gpu: true
4
+ python_version: "3.8"
5
+ system_packages:
6
+ - "libgl1-mesa-glx"
7
+ - "libglib2.0-0"
8
+ python_packages:
9
+ - "numpy==1.21.5"
10
+ - "cmake==3.23.3"
11
+ - "ipython==7.21.0"
12
+ - "opencv-python==4.5.4.58"
13
+ - "pandas==1.1.5"
14
+ - "lmdb==1.2.1"
15
+ - "lpips==0.1.4"
16
+ - "pytorch-fid==0.2.0"
17
+ - "ftfy==6.1.1"
18
+ - "scipy==1.5.4"
19
+ - "torch==1.9.1"
20
+ - "torchvision==0.10.1"
21
+ - "tqdm==4.62.3"
22
+ - "regex==2022.7.25"
23
+ - "Pillow==9.2.0"
24
+ - "pytorch_lightning==1.7.0"
25
+
26
+ run:
27
+ - pip install dlib
28
+
29
+ predict: "predict.py:Predictor"
config.py ADDED
@@ -0,0 +1,425 @@
1
+ from model.unet import ScaleAt
2
+ from model.latentnet import *
3
+ from diffusion.resample import UniformSampler
4
+ from diffusion.diffusion import space_timesteps
5
+ from typing import Tuple
6
+
7
+ from torch.utils.data import DataLoader
8
+
9
+ from config_base import BaseConfig
10
+ from dataset import *
11
+ from diffusion import *
12
+ from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
13
+ from model import *
14
+ from choices import *
15
+ from multiprocessing import get_context
16
+ import os
17
+ from dataset_util import *
18
+ from torch.utils.data.distributed import DistributedSampler
19
+
20
+ data_paths = {
21
+ 'ffhqlmdb256':
22
+ os.path.expanduser('datasets/ffhq256.lmdb'),
23
+ # used for training a classifier
24
+ 'celeba':
25
+ os.path.expanduser('datasets/celeba'),
26
+ # used for training DPM models
27
+ 'celebalmdb':
28
+ os.path.expanduser('datasets/celeba.lmdb'),
29
+ 'celebahq':
30
+ os.path.expanduser('datasets/celebahq256.lmdb'),
31
+ 'horse256':
32
+ os.path.expanduser('datasets/horse256.lmdb'),
33
+ 'bedroom256':
34
+ os.path.expanduser('datasets/bedroom256.lmdb'),
35
+ 'celeba_anno':
36
+ os.path.expanduser('datasets/celeba_anno/list_attr_celeba.txt'),
37
+ 'celebahq_anno':
38
+ os.path.expanduser(
39
+ 'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
40
+ 'celeba_relight':
41
+ os.path.expanduser('datasets/celeba_hq_light/celeba_light.txt'),
42
+ }
43
+
44
+
45
+ @dataclass
46
+ class PretrainConfig(BaseConfig):
47
+ name: str
48
+ path: str
49
+
50
+
51
+ @dataclass
52
+ class TrainConfig(BaseConfig):
53
+ # random seed
54
+ seed: int = 0
55
+ train_mode: TrainMode = TrainMode.diffusion
56
+ train_cond0_prob: float = 0
57
+ train_pred_xstart_detach: bool = True
58
+ train_interpolate_prob: float = 0
59
+ train_interpolate_img: bool = False
60
+ manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
61
+ manipulate_cls: str = None
62
+ manipulate_shots: int = None
63
+ manipulate_loss: ManipulateLossType = ManipulateLossType.bce
64
+ manipulate_znormalize: bool = False
65
+ manipulate_seed: int = 0
66
+ accum_batches: int = 1
67
+ autoenc_mid_attn: bool = True
68
+ batch_size: int = 16
69
+ batch_size_eval: int = None
70
+ beatgans_gen_type: GenerativeType = GenerativeType.ddim
71
+ beatgans_loss_type: LossType = LossType.mse
72
+ beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
73
+ beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
74
+ beatgans_rescale_timesteps: bool = False
75
+ latent_infer_path: str = None
76
+ latent_znormalize: bool = False
77
+ latent_gen_type: GenerativeType = GenerativeType.ddim
78
+ latent_loss_type: LossType = LossType.mse
79
+ latent_model_mean_type: ModelMeanType = ModelMeanType.eps
80
+ latent_model_var_type: ModelVarType = ModelVarType.fixed_large
81
+ latent_rescale_timesteps: bool = False
82
+ latent_T_eval: int = 1_000
83
+ latent_clip_sample: bool = False
84
+ latent_beta_scheduler: str = 'linear'
85
+ beta_scheduler: str = 'linear'
86
+ data_name: str = ''
87
+ data_val_name: str = None
88
+ diffusion_type: str = None
89
+ dropout: float = 0.1
90
+ ema_decay: float = 0.9999
91
+ eval_num_images: int = 5_000
92
+ eval_every_samples: int = 200_000
93
+ eval_ema_every_samples: int = 200_000
94
+ fid_use_torch: bool = True
95
+ fp16: bool = False
96
+ grad_clip: float = 1
97
+ img_size: int = 64
98
+ lr: float = 0.0001
99
+ optimizer: OptimizerType = OptimizerType.adam
100
+ weight_decay: float = 0
101
+ model_conf: ModelConfig = None
102
+ model_name: ModelName = None
103
+ model_type: ModelType = None
104
+ net_attn: Tuple[int] = None
105
+ net_beatgans_attn_head: int = 1
106
+ # not necessarily the same as the number of style channels
107
+ net_beatgans_embed_channels: int = 512
108
+ net_resblock_updown: bool = True
109
+ net_enc_use_time: bool = False
110
+ net_enc_pool: str = 'adaptivenonzero'
111
+ net_beatgans_gradient_checkpoint: bool = False
112
+ net_beatgans_resnet_two_cond: bool = False
113
+ net_beatgans_resnet_use_zero_module: bool = True
114
+ net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
115
+ net_beatgans_resnet_cond_channels: int = None
116
+ net_ch_mult: Tuple[int] = None
117
+ net_ch: int = 64
118
+ net_enc_attn: Tuple[int] = None
119
+ net_enc_k: int = None
120
+ # number of resblocks for the encoder (half-unet)
121
+ net_enc_num_res_blocks: int = 2
122
+ net_enc_channel_mult: Tuple[int] = None
123
+ net_enc_grad_checkpoint: bool = False
124
+ net_autoenc_stochastic: bool = False
125
+ net_latent_activation: Activation = Activation.silu
126
+ net_latent_channel_mult: Tuple[int] = (1, 2, 4)
127
+ net_latent_condition_bias: float = 0
128
+ net_latent_dropout: float = 0
129
+ net_latent_layers: int = None
130
+ net_latent_net_last_act: Activation = Activation.none
131
+ net_latent_net_type: LatentNetType = LatentNetType.none
132
+ net_latent_num_hid_channels: int = 1024
133
+ net_latent_num_time_layers: int = 2
134
+ net_latent_skip_layers: Tuple[int] = None
135
+ net_latent_time_emb_channels: int = 64
136
+ net_latent_use_norm: bool = False
137
+ net_latent_time_last_act: bool = False
138
+ net_num_res_blocks: int = 2
139
+ # number of resblocks for the UNET
140
+ net_num_input_res_blocks: int = None
141
+ net_enc_num_cls: int = None
142
+ num_workers: int = 4
143
+ parallel: bool = False
144
+ postfix: str = ''
145
+ sample_size: int = 64
146
+ sample_every_samples: int = 20_000
147
+ save_every_samples: int = 100_000
148
+ style_ch: int = 512
149
+ T_eval: int = 1_000
150
+ T_sampler: str = 'uniform'
151
+ T: int = 1_000
152
+ total_samples: int = 10_000_000
153
+ warmup: int = 0
154
+ pretrain: PretrainConfig = None
155
+ continue_from: PretrainConfig = None
156
+ eval_programs: Tuple[str] = None
157
+ # if present load the checkpoint from this path instead
158
+ eval_path: str = None
159
+ base_dir: str = 'checkpoints'
160
+ use_cache_dataset: bool = False
161
+ data_cache_dir: str = os.path.expanduser('~/cache')
162
+ work_cache_dir: str = os.path.expanduser('~/mycache')
163
+ # to be overridden
164
+ name: str = ''
165
+
166
+ def __post_init__(self):
167
+ self.batch_size_eval = self.batch_size_eval or self.batch_size
168
+ self.data_val_name = self.data_val_name or self.data_name
169
+
170
+ def scale_up_gpus(self, num_gpus, num_nodes=1):
171
+ self.eval_ema_every_samples *= num_gpus * num_nodes
172
+ self.eval_every_samples *= num_gpus * num_nodes
173
+ self.sample_every_samples *= num_gpus * num_nodes
174
+ self.batch_size *= num_gpus * num_nodes
175
+ self.batch_size_eval *= num_gpus * num_nodes
176
+ return self
177
+
178
+ @property
179
+ def batch_size_effective(self):
180
+ return self.batch_size * self.accum_batches
181
+
182
+ @property
183
+ def fid_cache(self):
184
+ # we try to use the local dirs to reduce the load over network drives
185
+ # hopefully, this would reduce the disconnection problems with sshfs
186
+ return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}'
187
+
188
+ @property
189
+ def data_path(self):
190
+ # may use the cache dir
191
+ path = data_paths[self.data_name]
192
+ if self.use_cache_dataset and path is not None:
193
+ path = use_cached_dataset_path(
194
+ path, f'{self.data_cache_dir}/{self.data_name}')
195
+ return path
196
+
197
+ @property
198
+ def logdir(self):
199
+ return f'{self.base_dir}/{self.name}'
200
+
201
+ @property
202
+ def generate_dir(self):
203
+ # we try to use the local dirs to reduce the load over network drives
204
+ # hopefully, this would reduce the disconnection problems with sshfs
205
+ return f'{self.work_cache_dir}/gen_images/{self.name}'
206
+
207
+ def _make_diffusion_conf(self, T=None):
208
+ if self.diffusion_type == 'beatgans':
209
+ # can use T < self.T for evaluation
210
+ # follows the guided-diffusion repo conventions
211
+ # t's are evenly spaced
212
+ if self.beatgans_gen_type == GenerativeType.ddpm:
213
+ section_counts = [T]
214
+ elif self.beatgans_gen_type == GenerativeType.ddim:
215
+ section_counts = f'ddim{T}'
216
+ else:
217
+ raise NotImplementedError()
218
+
219
+ return SpacedDiffusionBeatGansConfig(
220
+ gen_type=self.beatgans_gen_type,
221
+ model_type=self.model_type,
222
+ betas=get_named_beta_schedule(self.beta_scheduler, self.T),
223
+ model_mean_type=self.beatgans_model_mean_type,
224
+ model_var_type=self.beatgans_model_var_type,
225
+ loss_type=self.beatgans_loss_type,
226
+ rescale_timesteps=self.beatgans_rescale_timesteps,
227
+ use_timesteps=space_timesteps(num_timesteps=self.T,
228
+ section_counts=section_counts),
229
+ fp16=self.fp16,
230
+ )
231
+ else:
232
+ raise NotImplementedError()
233
+
234
+ def _make_latent_diffusion_conf(self, T=None):
235
+ # can use T < self.T for evaluation
236
+ # follows the guided-diffusion repo conventions
237
+ # t's are evenly spaced
238
+ if self.latent_gen_type == GenerativeType.ddpm:
239
+ section_counts = [T]
240
+ elif self.latent_gen_type == GenerativeType.ddim:
241
+ section_counts = f'ddim{T}'
242
+ else:
243
+ raise NotImplementedError()
244
+
245
+ return SpacedDiffusionBeatGansConfig(
246
+ train_pred_xstart_detach=self.train_pred_xstart_detach,
247
+ gen_type=self.latent_gen_type,
248
+ # latent's model is always ddpm
249
+ model_type=ModelType.ddpm,
250
+ # latent shares the beta scheduler and full T
251
+ betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
252
+ model_mean_type=self.latent_model_mean_type,
253
+ model_var_type=self.latent_model_var_type,
254
+ loss_type=self.latent_loss_type,
255
+ rescale_timesteps=self.latent_rescale_timesteps,
256
+ use_timesteps=space_timesteps(num_timesteps=self.T,
257
+ section_counts=section_counts),
258
+ fp16=self.fp16,
259
+ )
260
+
261
+ @property
262
+ def model_out_channels(self):
263
+ return 3
264
+
265
+ def make_T_sampler(self):
266
+ if self.T_sampler == 'uniform':
267
+ return UniformSampler(self.T)
268
+ else:
269
+ raise NotImplementedError()
270
+
271
+ def make_diffusion_conf(self):
272
+ return self._make_diffusion_conf(self.T)
273
+
274
+ def make_eval_diffusion_conf(self):
275
+ return self._make_diffusion_conf(T=self.T_eval)
276
+
277
+ def make_latent_diffusion_conf(self):
278
+ return self._make_latent_diffusion_conf(T=self.T)
279
+
280
+ def make_latent_eval_diffusion_conf(self):
281
+ # latent can have different eval T
282
+ return self._make_latent_diffusion_conf(T=self.latent_T_eval)
283
+
284
+ def make_dataset(self, path=None, **kwargs):
285
+ if self.data_name == 'ffhqlmdb256':
286
+ return FFHQlmdb(path=path or self.data_path,
287
+ image_size=self.img_size,
288
+ **kwargs)
289
+ elif self.data_name == 'horse256':
290
+ return Horse_lmdb(path=path or self.data_path,
291
+ image_size=self.img_size,
292
+ **kwargs)
293
+ elif self.data_name == 'bedroom256':
294
+ return Horse_lmdb(path=path or self.data_path,
295
+ image_size=self.img_size,
296
+ **kwargs)
297
+ elif self.data_name == 'celebalmdb':
298
+ # always use d2c crop
299
+ return CelebAlmdb(path=path or self.data_path,
300
+ image_size=self.img_size,
301
+ original_resolution=None,
302
+ crop_d2c=True,
303
+ **kwargs)
304
+ else:
305
+ raise NotImplementedError()
306
+
307
+ def make_loader(self,
308
+ dataset,
309
+ shuffle: bool,
310
+ num_worker: bool = None,
311
+ drop_last: bool = True,
312
+ batch_size: int = None,
313
+ parallel: bool = False):
314
+ if parallel and distributed.is_initialized():
315
+ # drop last to make sure that there is no added special indexes
316
+ sampler = DistributedSampler(dataset,
317
+ shuffle=shuffle,
318
+ drop_last=True)
319
+ else:
320
+ sampler = None
321
+ return DataLoader(
322
+ dataset,
323
+ batch_size=batch_size or self.batch_size,
324
+ sampler=sampler,
325
+ # with sampler, use the sample instead of this option
326
+ shuffle=False if sampler else shuffle,
327
+ num_workers=num_worker or self.num_workers,
328
+ pin_memory=True,
329
+ drop_last=drop_last,
330
+ multiprocessing_context=get_context('fork'),
331
+ )
332
+
333
+ def make_model_conf(self):
334
+ if self.model_name == ModelName.beatgans_ddpm:
335
+ self.model_type = ModelType.ddpm
336
+ self.model_conf = BeatGANsUNetConfig(
337
+ attention_resolutions=self.net_attn,
338
+ channel_mult=self.net_ch_mult,
339
+ conv_resample=True,
340
+ dims=2,
341
+ dropout=self.dropout,
342
+ embed_channels=self.net_beatgans_embed_channels,
343
+ image_size=self.img_size,
344
+ in_channels=3,
345
+ model_channels=self.net_ch,
346
+ num_classes=None,
347
+ num_head_channels=-1,
348
+ num_heads_upsample=-1,
349
+ num_heads=self.net_beatgans_attn_head,
350
+ num_res_blocks=self.net_num_res_blocks,
351
+ num_input_res_blocks=self.net_num_input_res_blocks,
352
+ out_channels=self.model_out_channels,
353
+ resblock_updown=self.net_resblock_updown,
354
+ use_checkpoint=self.net_beatgans_gradient_checkpoint,
355
+ use_new_attention_order=False,
356
+ resnet_two_cond=self.net_beatgans_resnet_two_cond,
357
+ resnet_use_zero_module=self.
358
+ net_beatgans_resnet_use_zero_module,
359
+ )
360
+ elif self.model_name in [
361
+ ModelName.beatgans_autoenc,
362
+ ]:
363
+ cls = BeatGANsAutoencConfig
364
+ # supports both autoenc and vaeddpm
365
+ if self.model_name == ModelName.beatgans_autoenc:
366
+ self.model_type = ModelType.autoencoder
367
+ else:
368
+ raise NotImplementedError()
369
+
370
+ if self.net_latent_net_type == LatentNetType.none:
371
+ latent_net_conf = None
372
+ elif self.net_latent_net_type == LatentNetType.skip:
373
+ latent_net_conf = MLPSkipNetConfig(
374
+ num_channels=self.style_ch,
375
+ skip_layers=self.net_latent_skip_layers,
376
+ num_hid_channels=self.net_latent_num_hid_channels,
377
+ num_layers=self.net_latent_layers,
378
+ num_time_emb_channels=self.net_latent_time_emb_channels,
379
+ activation=self.net_latent_activation,
380
+ use_norm=self.net_latent_use_norm,
381
+ condition_bias=self.net_latent_condition_bias,
382
+ dropout=self.net_latent_dropout,
383
+ last_act=self.net_latent_net_last_act,
384
+ num_time_layers=self.net_latent_num_time_layers,
385
+ time_last_act=self.net_latent_time_last_act,
386
+ )
387
+ else:
388
+ raise NotImplementedError()
389
+
390
+ self.model_conf = cls(
391
+ attention_resolutions=self.net_attn,
392
+ channel_mult=self.net_ch_mult,
393
+ conv_resample=True,
394
+ dims=2,
395
+ dropout=self.dropout,
396
+ embed_channels=self.net_beatgans_embed_channels,
397
+ enc_out_channels=self.style_ch,
398
+ enc_pool=self.net_enc_pool,
399
+ enc_num_res_block=self.net_enc_num_res_blocks,
400
+ enc_channel_mult=self.net_enc_channel_mult,
401
+ enc_grad_checkpoint=self.net_enc_grad_checkpoint,
402
+ enc_attn_resolutions=self.net_enc_attn,
403
+ image_size=self.img_size,
404
+ in_channels=3,
405
+ model_channels=self.net_ch,
406
+ num_classes=None,
407
+ num_head_channels=-1,
408
+ num_heads_upsample=-1,
409
+ num_heads=self.net_beatgans_attn_head,
410
+ num_res_blocks=self.net_num_res_blocks,
411
+ num_input_res_blocks=self.net_num_input_res_blocks,
412
+ out_channels=self.model_out_channels,
413
+ resblock_updown=self.net_resblock_updown,
414
+ use_checkpoint=self.net_beatgans_gradient_checkpoint,
415
+ use_new_attention_order=False,
416
+ resnet_two_cond=self.net_beatgans_resnet_two_cond,
417
+ resnet_use_zero_module=self.
418
+ net_beatgans_resnet_use_zero_module,
419
+ latent_net_conf=latent_net_conf,
420
+ resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
421
+ )
422
+ else:
423
+ raise NotImplementedError(self.model_name)
424
+
425
+ return self.model_conf
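`TrainConfig` collects every training knob into one dataclass; `scale_up_gpus` rescales the batch sizes and the eval/sample intervals together when moving to multi-GPU runs. A minimal sketch of how it might be used, assuming the rest of the repo (`model/`, `diffusion/`, `dataset.py`) is importable and with illustrative values:

```
from config import TrainConfig

conf = TrainConfig(data_name='ffhqlmdb256', batch_size=16, accum_batches=2)
conf.scale_up_gpus(4)                 # multiplies batch sizes and eval/sample intervals
print(conf.batch_size)                # 64
print(conf.batch_size_effective)      # 128 (batch_size * accum_batches)
print(conf.data_path)                 # 'datasets/ffhq256.lmdb' via the data_paths table
```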
config_base.py ADDED
@@ -0,0 +1,72 @@
1
+ import json
2
+ import os
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass
5
+
6
+
7
+ @dataclass
8
+ class BaseConfig:
9
+ def clone(self):
10
+ return deepcopy(self)
11
+
12
+ def inherit(self, another):
13
+ """inherit common keys from a given config"""
14
+ common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
15
+ for k in common_keys:
16
+ setattr(self, k, getattr(another, k))
17
+
18
+ def propagate(self):
19
+ """push down the configuration to all members"""
20
+ for k, v in self.__dict__.items():
21
+ if isinstance(v, BaseConfig):
22
+ v.inherit(self)
23
+ v.propagate()
24
+
25
+ def save(self, save_path):
26
+ """save config to json file"""
27
+ dirname = os.path.dirname(save_path)
28
+ if not os.path.exists(dirname):
29
+ os.makedirs(dirname)
30
+ conf = self.as_dict_jsonable()
31
+ with open(save_path, 'w') as f:
32
+ json.dump(conf, f)
33
+
34
+ def load(self, load_path):
35
+ """load json config"""
36
+ with open(load_path) as f:
37
+ conf = json.load(f)
38
+ self.from_dict(conf)
39
+
40
+ def from_dict(self, dict, strict=False):
41
+ for k, v in dict.items():
42
+ if not hasattr(self, k):
43
+ if strict:
44
+ raise ValueError(f"loading extra '{k}'")
45
+ else:
46
+ print(f"loading extra '{k}'")
47
+ continue
48
+ if isinstance(self.__dict__[k], BaseConfig):
49
+ self.__dict__[k].from_dict(v)
50
+ else:
51
+ self.__dict__[k] = v
52
+
53
+ def as_dict_jsonable(self):
54
+ conf = {}
55
+ for k, v in self.__dict__.items():
56
+ if isinstance(v, BaseConfig):
57
+ conf[k] = v.as_dict_jsonable()
58
+ else:
59
+ if jsonable(v):
60
+ conf[k] = v
61
+ else:
62
+ # ignore not jsonable
63
+ pass
64
+ return conf
65
+
66
+
67
+ def jsonable(x):
68
+ try:
69
+ json.dumps(x)
70
+ return True
71
+ except TypeError:
72
+ return False
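`BaseConfig` gives each dataclass config cloning, field inheritance, and JSON round-tripping of its JSON-serializable fields. A minimal sketch with a hypothetical `DemoConfig` (not part of the repo):

```
from dataclasses import dataclass

from config_base import BaseConfig

@dataclass
class DemoConfig(BaseConfig):
    lr: float = 1e-4
    name: str = 'demo'

conf = DemoConfig(lr=3e-4)
conf.save('checkpoints/demo/config.json')   # writes only the jsonable fields

restored = DemoConfig()
restored.load('checkpoints/demo/config.json')
assert restored.lr == 3e-4 and restored.name == 'demo'
```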
data_resize_bedroom.py ADDED
@@ -0,0 +1,101 @@
1
+ import argparse
2
+ import multiprocessing
3
+ import os
4
+ from os.path import join, exists
5
+ from functools import partial
6
+ from io import BytesIO
7
+ import shutil
8
+
9
+ import lmdb
10
+ from PIL import Image
11
+ from torchvision.datasets import LSUNClass
12
+ from torchvision.transforms import functional as trans_fn
13
+ from tqdm import tqdm
14
+
15
+ from multiprocessing import Process, Queue
16
+
17
+
18
+ def resize_and_convert(img, size, resample, quality=100):
19
+ img = trans_fn.resize(img, size, resample)
20
+ img = trans_fn.center_crop(img, size)
21
+ buffer = BytesIO()
22
+ img.save(buffer, format="webp", quality=quality)
23
+ val = buffer.getvalue()
24
+
25
+ return val
26
+
27
+
28
+ def resize_multiple(img,
29
+ sizes=(128, 256, 512, 1024),
30
+ resample=Image.LANCZOS,
31
+ quality=100):
32
+ imgs = []
33
+
34
+ for size in sizes:
35
+ imgs.append(resize_and_convert(img, size, resample, quality))
36
+
37
+ return imgs
38
+
39
+
40
+ def resize_worker(idx, img, sizes, resample):
41
+ img = img.convert("RGB")
42
+ out = resize_multiple(img, sizes=sizes, resample=resample)
43
+ return idx, out
44
+
45
+
46
+ from torch.utils.data import Dataset, DataLoader
47
+
48
+
49
+ class ConvertDataset(Dataset):
50
+ def __init__(self, data) -> None:
51
+ self.data = data
52
+
53
+ def __len__(self):
54
+ return len(self.data)
55
+
56
+ def __getitem__(self, index):
57
+ img, _ = self.data[index]
58
+ bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90)
59
+ return bytes
60
+
61
+
62
+ if __name__ == "__main__":
63
+ """
64
+ Convert LSUN's original LMDB into our LMDB format, which is more performant.
65
+ """
66
+ from tqdm import tqdm
67
+
68
+ # path to the original lsun's lmdb
69
+ src_path = 'datasets/bedroom_train_lmdb'
70
+ out_path = 'datasets/bedroom256.lmdb'
71
+
72
+ dataset = LSUNClass(root=os.path.expanduser(src_path))
73
+ dataset = ConvertDataset(dataset)
74
+ loader = DataLoader(dataset,
75
+ batch_size=50,
76
+ num_workers=12,
77
+ collate_fn=lambda x: x,
78
+ shuffle=False)
79
+
80
+ target = os.path.expanduser(out_path)
81
+ if os.path.exists(target):
82
+ shutil.rmtree(target)
83
+
84
+ with lmdb.open(target, map_size=1024**4, readahead=False) as env:
85
+ with tqdm(total=len(dataset)) as progress:
86
+ i = 0
87
+ for batch in loader:
88
+ with env.begin(write=True) as txn:
89
+ for img in batch:
90
+ key = f"{256}-{str(i).zfill(7)}".encode("utf-8")
91
+ # print(key)
92
+ txn.put(key, img)
93
+ i += 1
94
+ progress.update()
95
+ # if i == 1000:
96
+ # break
97
+ # if total == len(imgset):
98
+ # break
99
+
100
+ with env.begin(write=True) as txn:
101
+ txn.put("length".encode("utf-8"), str(i).encode("utf-8"))
data_resize_celeba.py ADDED
@@ -0,0 +1,116 @@
1
+ import argparse
2
+ import multiprocessing
3
+ import os
4
+ import shutil
5
+ from functools import partial
6
+ from io import BytesIO
7
+ from multiprocessing import Process, Queue
8
+ from os.path import exists, join
9
+ from pathlib import Path
10
+
11
+ import lmdb
12
+ from PIL import Image
13
+ from torch.utils.data import DataLoader, Dataset
14
+ from torchvision.datasets import LSUNClass
15
+ from torchvision.transforms import functional as trans_fn
16
+ from tqdm import tqdm
17
+
18
+
19
+ def resize_and_convert(img, size, resample, quality=100):
20
+ if size is not None:
21
+ img = trans_fn.resize(img, size, resample)
22
+ img = trans_fn.center_crop(img, size)
23
+
24
+ buffer = BytesIO()
25
+ img.save(buffer, format="webp", quality=quality)
26
+ val = buffer.getvalue()
27
+
28
+ return val
29
+
30
+
31
+ def resize_multiple(img,
32
+ sizes=(128, 256, 512, 1024),
33
+ resample=Image.LANCZOS,
34
+ quality=100):
35
+ imgs = []
36
+
37
+ for size in sizes:
38
+ imgs.append(resize_and_convert(img, size, resample, quality))
39
+
40
+ return imgs
41
+
42
+
43
+ def resize_worker(idx, img, sizes, resample):
44
+ img = img.convert("RGB")
45
+ out = resize_multiple(img, sizes=sizes, resample=resample)
46
+ return idx, out
47
+
48
+
49
+ class ConvertDataset(Dataset):
50
+ def __init__(self, data, size) -> None:
51
+ self.data = data
52
+ self.size = size
53
+
54
+ def __len__(self):
55
+ return len(self.data)
56
+
57
+ def __getitem__(self, index):
58
+ img = self.data[index]
59
+ bytes = resize_and_convert(img, self.size, Image.LANCZOS, quality=100)
60
+ return bytes
61
+
62
+
63
+ class ImageFolder(Dataset):
64
+ def __init__(self, folder, ext='jpg'):
65
+ super().__init__()
66
+ paths = sorted([p for p in Path(f'{folder}').glob(f'*.{ext}')])
67
+ self.paths = paths
68
+
69
+ def __len__(self):
70
+ return len(self.paths)
71
+
72
+ def __getitem__(self, index):
73
+ path = os.path.join(self.paths[index])
74
+ img = Image.open(path)
75
+ return img
76
+
77
+
78
+ if __name__ == "__main__":
79
+ from tqdm import tqdm
80
+
81
+ out_path = 'datasets/celeba.lmdb'
82
+ in_path = 'datasets/celeba'
83
+ ext = 'jpg'
84
+ size = None
85
+
86
+ dataset = ImageFolder(in_path, ext)
87
+ print('len:', len(dataset))
88
+ dataset = ConvertDataset(dataset, size)
89
+ loader = DataLoader(dataset,
90
+ batch_size=50,
91
+ num_workers=12,
92
+ collate_fn=lambda x: x,
93
+ shuffle=False)
94
+
95
+ target = os.path.expanduser(out_path)
96
+ if os.path.exists(target):
97
+ shutil.rmtree(target)
98
+
99
+ with lmdb.open(target, map_size=1024**4, readahead=False) as env:
100
+ with tqdm(total=len(dataset)) as progress:
101
+ i = 0
102
+ for batch in loader:
103
+ with env.begin(write=True) as txn:
104
+ for img in batch:
105
+ key = f"{size}-{str(i).zfill(7)}".encode("utf-8")
106
+ # print(key)
107
+ txn.put(key, img)
108
+ i += 1
109
+ progress.update()
110
+ # if i == 1000:
111
+ # break
112
+ # if total == len(imgset):
113
+ # break
114
+
115
+ with env.begin(write=True) as txn:
116
+ txn.put("length".encode("utf-8"), str(i).encode("utf-8"))
data_resize_celebahq.py ADDED
@@ -0,0 +1,117 @@
+ import argparse
+ import multiprocessing
+ from functools import partial
+ from io import BytesIO
+ from pathlib import Path
+
+ import lmdb
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision.transforms import functional as trans_fn
+ from tqdm import tqdm
+ import os
+
+
+ def resize_and_convert(img, size, resample, quality=100):
+     img = trans_fn.resize(img, size, resample)
+     img = trans_fn.center_crop(img, size)
+     buffer = BytesIO()
+     img.save(buffer, format="jpeg", quality=quality)
+     val = buffer.getvalue()
+
+     return val
+
+
+ def resize_multiple(img,
+                     sizes=(128, 256, 512, 1024),
+                     resample=Image.LANCZOS,
+                     quality=100):
+     imgs = []
+
+     for size in sizes:
+         imgs.append(resize_and_convert(img, size, resample, quality))
+
+     return imgs
+
+
+ def resize_worker(img_file, sizes, resample):
+     i, (file, idx) = img_file
+     img = Image.open(file)
+     img = img.convert("RGB")
+     out = resize_multiple(img, sizes=sizes, resample=resample)
+
+     return i, idx, out
+
+
+ def prepare(env,
+             paths,
+             n_worker,
+             sizes=(128, 256, 512, 1024),
+             resample=Image.LANCZOS):
+     resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
+
+     # index = filename in int
+     indexs = []
+     for each in paths:
+         file = os.path.basename(each)
+         name, ext = file.split('.')
+         idx = int(name)
+         indexs.append(idx)
+
+     # sort by file index
+     files = sorted(zip(paths, indexs), key=lambda x: x[1])
+     files = list(enumerate(files))
+     total = 0
+
+     with multiprocessing.Pool(n_worker) as pool:
+         for i, idx, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
+             for size, img in zip(sizes, imgs):
+                 key = f"{size}-{str(idx).zfill(5)}".encode("utf-8")
+
+                 with env.begin(write=True) as txn:
+                     txn.put(key, img)
+
+             total += 1
+
+         with env.begin(write=True) as txn:
+             txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
+
+
+ class ImageFolder(Dataset):
+     def __init__(self, folder, exts=['jpg']):
+         super().__init__()
+         self.paths = [
+             p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')
+         ]
+
+     def __len__(self):
+         return len(self.paths)
+
+     def __getitem__(self, index):
+         # paths from glob already include the folder prefix
+         path = self.paths[index]
+         img = Image.open(path)
+         return img
+
+
+ if __name__ == "__main__":
+     """
+     converting celebahq images to lmdb
+     """
+     num_workers = 16
+     in_path = 'datasets/celebahq'
+     out_path = 'datasets/celebahq256.lmdb'
+
+     resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
+     resample = resample_map['lanczos']
+
+     sizes = [256]
+
+     print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
+
+     # imgset = datasets.ImageFolder(in_path)
+     # imgset = ImageFolder(in_path)
+     exts = ['jpg']
+     paths = [p for ext in exts for p in Path(f'{in_path}').glob(f'**/*.{ext}')]
+
+     with lmdb.open(out_path, map_size=1024**4, readahead=False) as env:
+         prepare(env, paths, num_workers, sizes=sizes, resample=resample)
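Note the key layout written by `prepare()` above: one record per `(size, image index)` pair plus a global `length` counter, with the image index zero-padded to 5 digits (the celeba and horse converters use 7). A small illustration with a hypothetical index:

```python
sizes = [256]
idx = 123  # integer part of a source filename such as 00123.jpg
keys = [f"{s}-{str(idx).zfill(5)}".encode("utf-8") for s in sizes]
# -> [b'256-00123']; a reader must use the same padding width,
#    e.g. BaseLMDB(path, original_resolution=256, zfill=5) in dataset.py.
```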
data_resize_ffhq.py ADDED
@@ -0,0 +1,123 @@
+ import argparse
+ import multiprocessing
+ from functools import partial
+ from io import BytesIO
+ from pathlib import Path
+
+ import lmdb
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision.transforms import functional as trans_fn
+ from tqdm import tqdm
+ import os
+
+
+ def resize_and_convert(img, size, resample, quality=100):
+     img = trans_fn.resize(img, size, resample)
+     img = trans_fn.center_crop(img, size)
+     buffer = BytesIO()
+     img.save(buffer, format="jpeg", quality=quality)
+     val = buffer.getvalue()
+
+     return val
+
+
+ def resize_multiple(img,
+                     sizes=(128, 256, 512, 1024),
+                     resample=Image.LANCZOS,
+                     quality=100):
+     imgs = []
+
+     for size in sizes:
+         imgs.append(resize_and_convert(img, size, resample, quality))
+
+     return imgs
+
+
+ def resize_worker(img_file, sizes, resample):
+     i, (file, idx) = img_file
+     img = Image.open(file)
+     img = img.convert("RGB")
+     out = resize_multiple(img, sizes=sizes, resample=resample)
+
+     return i, idx, out
+
+
+ def prepare(env,
+             paths,
+             n_worker,
+             sizes=(128, 256, 512, 1024),
+             resample=Image.LANCZOS):
+     resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
+
+     # index = filename in int
+     indexs = []
+     for each in paths:
+         file = os.path.basename(each)
+         name, ext = file.split('.')
+         idx = int(name)
+         indexs.append(idx)
+
+     # sort by file index
+     files = sorted(zip(paths, indexs), key=lambda x: x[1])
+     files = list(enumerate(files))
+     total = 0
+
+     with multiprocessing.Pool(n_worker) as pool:
+         for i, idx, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
+             for size, img in zip(sizes, imgs):
+                 key = f"{size}-{str(idx).zfill(5)}".encode("utf-8")
+
+                 with env.begin(write=True) as txn:
+                     txn.put(key, img)
+
+             total += 1
+
+         with env.begin(write=True) as txn:
+             txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
+
+
+ class ImageFolder(Dataset):
+     def __init__(self, folder, exts=['jpg']):
+         super().__init__()
+         self.paths = [
+             p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')
+         ]
+
+     def __len__(self):
+         return len(self.paths)
+
+     def __getitem__(self, index):
+         # paths from glob already include the folder prefix
+         path = self.paths[index]
+         img = Image.open(path)
+         return img
+
+
+ if __name__ == "__main__":
+     """
+     converting ffhq images to lmdb
+     """
+     num_workers = 16
+     # original ffhq data path
+     in_path = 'datasets/ffhq'
+     # target output path
+     out_path = 'datasets/ffhq.lmdb'
+
+     if not os.path.exists(out_path):
+         os.makedirs(out_path)
+
+     resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
+     resample = resample_map['lanczos']
+
+     sizes = [256]
+
+     print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
+
+     # imgset = datasets.ImageFolder(in_path)
+     # imgset = ImageFolder(in_path)
+     exts = ['jpg']
+     paths = [p for ext in exts for p in Path(f'{in_path}').glob(f'**/*.{ext}')]
+     # print(paths[:10])
+
+     with lmdb.open(out_path, map_size=1024**4, readahead=False) as env:
+         prepare(env, paths, num_workers, sizes=sizes, resample=resample)
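A quick post-conversion sanity check, as a sketch (not part of the script; paths follow the defaults above):

```python
import lmdb
from pathlib import Path

n_src = len(list(Path('datasets/ffhq').glob('**/*.jpg')))
with lmdb.open('datasets/ffhq.lmdb', readonly=True, lock=False) as env:
    with env.begin(write=False) as txn:
        n_lmdb = int(txn.get("length".encode("utf-8")).decode("utf-8"))

print(n_src, n_lmdb)  # the two counts should agree (70,000 for the full FFHQ set)
```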
data_resize_horse.py ADDED
@@ -0,0 +1,97 @@
+ import argparse
+ import multiprocessing
+ import os
+ import shutil
+ from functools import partial
+ from io import BytesIO
+ from multiprocessing import Process, Queue
+ from os.path import exists, join
+
+ import lmdb
+ from PIL import Image
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision.datasets import LSUNClass
+ from torchvision.transforms import functional as trans_fn
+ from tqdm import tqdm
+
+
+ def resize_and_convert(img, size, resample, quality=100):
+     img = trans_fn.resize(img, size, resample)
+     img = trans_fn.center_crop(img, size)
+     buffer = BytesIO()
+     img.save(buffer, format="webp", quality=quality)
+     val = buffer.getvalue()
+
+     return val
+
+
+ def resize_multiple(img,
+                     sizes=(128, 256, 512, 1024),
+                     resample=Image.LANCZOS,
+                     quality=100):
+     imgs = []
+
+     for size in sizes:
+         imgs.append(resize_and_convert(img, size, resample, quality))
+
+     return imgs
+
+
+ def resize_worker(idx, img, sizes, resample):
+     img = img.convert("RGB")
+     out = resize_multiple(img, sizes=sizes, resample=resample)
+     return idx, out
+
+
+ class ConvertDataset(Dataset):
+     def __init__(self, data) -> None:
+         self.data = data
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, index):
+         img, _ = self.data[index]
+         bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90)
+         return bytes
+
+
+ if __name__ == "__main__":
+     """
+     converting LSUN's original lmdb into our lmdb format, which is more performant.
+     """
+     from tqdm import tqdm
+
+     # path to the original lsun's lmdb
+     src_path = 'datasets/horse_train_lmdb'
+     out_path = 'datasets/horse256.lmdb'
+
+     dataset = LSUNClass(root=os.path.expanduser(src_path))
+     dataset = ConvertDataset(dataset)
+     loader = DataLoader(dataset,
+                         batch_size=50,
+                         num_workers=16,
+                         collate_fn=lambda x: x)
+
+     target = os.path.expanduser(out_path)
+     if os.path.exists(target):
+         shutil.rmtree(target)
+
+     with lmdb.open(target, map_size=1024**4, readahead=False) as env:
+         with tqdm(total=len(dataset)) as progress:
+             i = 0
+             for batch in loader:
+                 with env.begin(write=True) as txn:
+                     for img in batch:
+                         key = f"{256}-{str(i).zfill(7)}".encode("utf-8")
+                         # print(key)
+                         txn.put(key, img)
+                         i += 1
+                         progress.update()
+                 # if i == 1000:
+                 #     break
+                 # if total == len(imgset):
+                 #     break
+
+         with env.begin(write=True) as txn:
+             txn.put("length".encode("utf-8"), str(i).encode("utf-8"))
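The same recipe should carry over to other LSUN categories by swapping the two paths; the bedroom values below are an assumption based on the defaults that `Bedroom_lmdb` in `dataset.py` expects:

```python
# hypothetical drop-in values for an LSUN bedroom conversion
src_path = 'datasets/bedroom_train_lmdb'  # original LSUN lmdb (assumed location)
out_path = 'datasets/bedroom256.lmdb'     # default path read by Bedroom_lmdb
# the key layout stays '256-XXXXXXX' plus a 'length' record, as above
```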
dataset.py ADDED
@@ -0,0 +1,716 @@
1
+ import os
2
+ from io import BytesIO
3
+ from pathlib import Path
4
+
5
+ import lmdb
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset
8
+ from torchvision import transforms
9
+ from torchvision.datasets import CIFAR10, LSUNClass
10
+ import torch
11
+ import pandas as pd
12
+
13
+ import torchvision.transforms.functional as Ftrans
14
+
15
+
16
+ class ImageDataset(Dataset):
17
+ def __init__(
18
+ self,
19
+ folder,
20
+ image_size,
21
+ exts=['jpg'],
22
+ do_augment: bool = True,
23
+ do_transform: bool = True,
24
+ do_normalize: bool = True,
25
+ sort_names=False,
26
+ has_subdir: bool = True,
27
+ ):
28
+ super().__init__()
29
+ self.folder = folder
30
+ self.image_size = image_size
31
+
32
+ # relative paths (make it shorter, saves memory and faster to sort)
33
+ if has_subdir:
34
+ self.paths = [
35
+ p.relative_to(folder) for ext in exts
36
+ for p in Path(f'{folder}').glob(f'**/*.{ext}')
37
+ ]
38
+ else:
39
+ self.paths = [
40
+ p.relative_to(folder) for ext in exts
41
+ for p in Path(f'{folder}').glob(f'*.{ext}')
42
+ ]
43
+ if sort_names:
44
+ self.paths = sorted(self.paths)
45
+
46
+ transform = [
47
+ transforms.Resize(image_size),
48
+ transforms.CenterCrop(image_size),
49
+ ]
50
+ if do_augment:
51
+ transform.append(transforms.RandomHorizontalFlip())
52
+ if do_transform:
53
+ transform.append(transforms.ToTensor())
54
+ if do_normalize:
55
+ transform.append(
56
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
57
+ self.transform = transforms.Compose(transform)
58
+
59
+ def __len__(self):
60
+ return len(self.paths)
61
+
62
+ def __getitem__(self, index):
63
+ path = os.path.join(self.folder, self.paths[index])
64
+ img = Image.open(path)
65
+ # if the image is 'rgba'!
66
+ img = img.convert('RGB')
67
+ if self.transform is not None:
68
+ img = self.transform(img)
69
+ return {'img': img, 'index': index}
70
+
71
+
72
+ class SubsetDataset(Dataset):
73
+ def __init__(self, dataset, size):
74
+ assert len(dataset) >= size
75
+ self.dataset = dataset
76
+ self.size = size
77
+
78
+ def __len__(self):
79
+ return self.size
80
+
81
+ def __getitem__(self, index):
82
+ assert index < self.size
83
+ return self.dataset[index]
84
+
85
+
86
+ class BaseLMDB(Dataset):
87
+ def __init__(self, path, original_resolution, zfill: int = 5):
88
+ self.original_resolution = original_resolution
89
+ self.zfill = zfill
90
+ self.env = lmdb.open(
91
+ path,
92
+ max_readers=32,
93
+ readonly=True,
94
+ lock=False,
95
+ readahead=False,
96
+ meminit=False,
97
+ )
98
+
99
+ if not self.env:
100
+ raise IOError('Cannot open lmdb dataset', path)
101
+
102
+ with self.env.begin(write=False) as txn:
103
+ self.length = int(
104
+ txn.get('length'.encode('utf-8')).decode('utf-8'))
105
+
106
+ def __len__(self):
107
+ return self.length
108
+
109
+ def __getitem__(self, index):
110
+ with self.env.begin(write=False) as txn:
111
+ key = f'{self.original_resolution}-{str(index).zfill(self.zfill)}'.encode(
112
+ 'utf-8')
113
+ img_bytes = txn.get(key)
114
+
115
+ buffer = BytesIO(img_bytes)
116
+ img = Image.open(buffer)
117
+ return img
118
+
119
+
120
+ def make_transform(
121
+ image_size,
122
+ flip_prob=0.5,
123
+ crop_d2c=False,
124
+ ):
125
+ if crop_d2c:
126
+ transform = [
127
+ d2c_crop(),
128
+ transforms.Resize(image_size),
129
+ ]
130
+ else:
131
+ transform = [
132
+ transforms.Resize(image_size),
133
+ transforms.CenterCrop(image_size),
134
+ ]
135
+ transform.append(transforms.RandomHorizontalFlip(p=flip_prob))
136
+ transform.append(transforms.ToTensor())
137
+ transform.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
138
+ transform = transforms.Compose(transform)
139
+ return transform
140
+
141
+
142
+ class FFHQlmdb(Dataset):
143
+ def __init__(self,
144
+ path=os.path.expanduser('datasets/ffhq256.lmdb'),
145
+ image_size=256,
146
+ original_resolution=256,
147
+ split=None,
148
+ as_tensor: bool = True,
149
+ do_augment: bool = True,
150
+ do_normalize: bool = True,
151
+ **kwargs):
152
+ self.original_resolution = original_resolution
153
+ self.data = BaseLMDB(path, original_resolution, zfill=5)
154
+ self.length = len(self.data)
155
+
156
+ if split is None:
157
+ self.offset = 0
158
+ elif split == 'train':
159
+ # last 60k
160
+ self.length = self.length - 10000
161
+ self.offset = 10000
162
+ elif split == 'test':
163
+ # first 10k
164
+ self.length = 10000
165
+ self.offset = 0
166
+ else:
167
+ raise NotImplementedError()
168
+
169
+ transform = [
170
+ transforms.Resize(image_size),
171
+ ]
172
+ if do_augment:
173
+ transform.append(transforms.RandomHorizontalFlip())
174
+ if as_tensor:
175
+ transform.append(transforms.ToTensor())
176
+ if do_normalize:
177
+ transform.append(
178
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
179
+ self.transform = transforms.Compose(transform)
180
+
181
+ def __len__(self):
182
+ return self.length
183
+
184
+ def __getitem__(self, index):
185
+ assert index < self.length
186
+ index = index + self.offset
187
+ img = self.data[index]
188
+ if self.transform is not None:
189
+ img = self.transform(img)
190
+ return {'img': img, 'index': index}
191
+
192
+
193
+ class Crop:
194
+ def __init__(self, x1, x2, y1, y2):
195
+ self.x1 = x1
196
+ self.x2 = x2
197
+ self.y1 = y1
198
+ self.y2 = y2
199
+
200
+ def __call__(self, img):
201
+ return Ftrans.crop(img, self.x1, self.y1, self.x2 - self.x1,
202
+ self.y2 - self.y1)
203
+
204
+ def __repr__(self):
205
+ return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
206
+ self.x1, self.x2, self.y1, self.y2)
207
+
208
+
209
+ def d2c_crop():
210
+ # from D2C paper for CelebA dataset.
211
+ cx = 89
212
+ cy = 121
213
+ x1 = cy - 64
214
+ x2 = cy + 64
215
+ y1 = cx - 64
216
+ y2 = cx + 64
217
+ return Crop(x1, x2, y1, y2)
218
+
219
+
220
+ class CelebAlmdb(Dataset):
221
+ """
222
+ also supports for d2c crop.
223
+ """
224
+ def __init__(self,
225
+ path,
226
+ image_size,
227
+ original_resolution=128,
228
+ split=None,
229
+ as_tensor: bool = True,
230
+ do_augment: bool = True,
231
+ do_normalize: bool = True,
232
+ crop_d2c: bool = False,
233
+ **kwargs):
234
+ self.original_resolution = original_resolution
235
+ self.data = BaseLMDB(path, original_resolution, zfill=7)
236
+ self.length = len(self.data)
237
+ self.crop_d2c = crop_d2c
238
+
239
+ if split is None:
240
+ self.offset = 0
241
+ else:
242
+ raise NotImplementedError()
243
+
244
+ if crop_d2c:
245
+ transform = [
246
+ d2c_crop(),
247
+ transforms.Resize(image_size),
248
+ ]
249
+ else:
250
+ transform = [
251
+ transforms.Resize(image_size),
252
+ transforms.CenterCrop(image_size),
253
+ ]
254
+
255
+ if do_augment:
256
+ transform.append(transforms.RandomHorizontalFlip())
257
+ if as_tensor:
258
+ transform.append(transforms.ToTensor())
259
+ if do_normalize:
260
+ transform.append(
261
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
262
+ self.transform = transforms.Compose(transform)
263
+
264
+ def __len__(self):
265
+ return self.length
266
+
267
+ def __getitem__(self, index):
268
+ assert index < self.length
269
+ index = index + self.offset
270
+ img = self.data[index]
271
+ if self.transform is not None:
272
+ img = self.transform(img)
273
+ return {'img': img, 'index': index}
274
+
275
+
276
+ class Horse_lmdb(Dataset):
277
+ def __init__(self,
278
+ path=os.path.expanduser('datasets/horse256.lmdb'),
279
+ image_size=128,
280
+ original_resolution=256,
281
+ do_augment: bool = True,
282
+ do_transform: bool = True,
283
+ do_normalize: bool = True,
284
+ **kwargs):
285
+ self.original_resolution = original_resolution
286
+ print(path)
287
+ self.data = BaseLMDB(path, original_resolution, zfill=7)
288
+ self.length = len(self.data)
289
+
290
+ transform = [
291
+ transforms.Resize(image_size),
292
+ transforms.CenterCrop(image_size),
293
+ ]
294
+ if do_augment:
295
+ transform.append(transforms.RandomHorizontalFlip())
296
+ if do_transform:
297
+ transform.append(transforms.ToTensor())
298
+ if do_normalize:
299
+ transform.append(
300
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
301
+ self.transform = transforms.Compose(transform)
302
+
303
+ def __len__(self):
304
+ return self.length
305
+
306
+ def __getitem__(self, index):
307
+ img = self.data[index]
308
+ if self.transform is not None:
309
+ img = self.transform(img)
310
+ return {'img': img, 'index': index}
311
+
312
+
313
+ class Bedroom_lmdb(Dataset):
314
+ def __init__(self,
315
+ path=os.path.expanduser('datasets/bedroom256.lmdb'),
316
+ image_size=128,
317
+ original_resolution=256,
318
+ do_augment: bool = True,
319
+ do_transform: bool = True,
320
+ do_normalize: bool = True,
321
+ **kwargs):
322
+ self.original_resolution = original_resolution
323
+ print(path)
324
+ self.data = BaseLMDB(path, original_resolution, zfill=7)
325
+ self.length = len(self.data)
326
+
327
+ transform = [
328
+ transforms.Resize(image_size),
329
+ transforms.CenterCrop(image_size),
330
+ ]
331
+ if do_augment:
332
+ transform.append(transforms.RandomHorizontalFlip())
333
+ if do_transform:
334
+ transform.append(transforms.ToTensor())
335
+ if do_normalize:
336
+ transform.append(
337
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
338
+ self.transform = transforms.Compose(transform)
339
+
340
+ def __len__(self):
341
+ return self.length
342
+
343
+ def __getitem__(self, index):
344
+ img = self.data[index]
345
+ img = self.transform(img)
346
+ return {'img': img, 'index': index}
347
+
348
+
349
+ class CelebAttrDataset(Dataset):
350
+
351
+ id_to_cls = [
352
+ '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
353
+ 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
354
+ 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
355
+ 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
356
+ 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
357
+ 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
358
+ 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
359
+ 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
360
+ 'Wearing_Necklace', 'Wearing_Necktie', 'Young'
361
+ ]
362
+ cls_to_id = {v: k for k, v in enumerate(id_to_cls)}
363
+
364
+ def __init__(self,
365
+ folder,
366
+ image_size=64,
367
+ attr_path=os.path.expanduser(
368
+ 'datasets/celeba_anno/list_attr_celeba.txt'),
369
+ ext='png',
370
+ only_cls_name: str = None,
371
+ only_cls_value: int = None,
372
+ do_augment: bool = False,
373
+ do_transform: bool = True,
374
+ do_normalize: bool = True,
375
+ d2c: bool = False):
376
+ super().__init__()
377
+ self.folder = folder
378
+ self.image_size = image_size
379
+ self.ext = ext
380
+
381
+ # relative paths (make it shorter, saves memory and faster to sort)
382
+ paths = [
383
+ str(p.relative_to(folder))
384
+ for p in Path(f'{folder}').glob(f'**/*.{ext}')
385
+ ]
386
+ paths = [str(each).split('.')[0] + '.jpg' for each in paths]
387
+
388
+ if d2c:
389
+ transform = [
390
+ d2c_crop(),
391
+ transforms.Resize(image_size),
392
+ ]
393
+ else:
394
+ transform = [
395
+ transforms.Resize(image_size),
396
+ transforms.CenterCrop(image_size),
397
+ ]
398
+ if do_augment:
399
+ transform.append(transforms.RandomHorizontalFlip())
400
+ if do_transform:
401
+ transform.append(transforms.ToTensor())
402
+ if do_normalize:
403
+ transform.append(
404
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
405
+ self.transform = transforms.Compose(transform)
406
+
407
+ with open(attr_path) as f:
408
+ # discard the top line
409
+ f.readline()
410
+ self.df = pd.read_csv(f, delim_whitespace=True)
411
+ self.df = self.df[self.df.index.isin(paths)]
412
+
413
+ if only_cls_name is not None:
414
+ self.df = self.df[self.df[only_cls_name] == only_cls_value]
415
+
416
+ def pos_count(self, cls_name):
417
+ return (self.df[cls_name] == 1).sum()
418
+
419
+ def neg_count(self, cls_name):
420
+ return (self.df[cls_name] == -1).sum()
421
+
422
+ def __len__(self):
423
+ return len(self.df)
424
+
425
+ def __getitem__(self, index):
426
+ row = self.df.iloc[index]
427
+ name = row.name.split('.')[0]
428
+ name = f'{name}.{self.ext}'
429
+
430
+ path = os.path.join(self.folder, name)
431
+ img = Image.open(path)
432
+
433
+ labels = [0] * len(self.id_to_cls)
434
+ for k, v in row.items():
435
+ labels[self.cls_to_id[k]] = int(v)
436
+
437
+ if self.transform is not None:
438
+ img = self.transform(img)
439
+
440
+ return {'img': img, 'index': index, 'labels': torch.tensor(labels)}
441
+
442
+
443
+ class CelebD2CAttrDataset(CelebAttrDataset):
444
+ """
445
+ the dataset is used in the D2C paper.
446
+ it has a specific crop from the original CelebA.
447
+ """
448
+ def __init__(self,
449
+ folder,
450
+ image_size=64,
451
+ attr_path=os.path.expanduser(
452
+ 'datasets/celeba_anno/list_attr_celeba.txt'),
453
+ ext='jpg',
454
+ only_cls_name: str = None,
455
+ only_cls_value: int = None,
456
+ do_augment: bool = False,
457
+ do_transform: bool = True,
458
+ do_normalize: bool = True,
459
+ d2c: bool = True):
460
+ super().__init__(folder,
461
+ image_size,
462
+ attr_path,
463
+ ext=ext,
464
+ only_cls_name=only_cls_name,
465
+ only_cls_value=only_cls_value,
466
+ do_augment=do_augment,
467
+ do_transform=do_transform,
468
+ do_normalize=do_normalize,
469
+ d2c=d2c)
470
+
471
+
472
+ class CelebAttrFewshotDataset(Dataset):
473
+ def __init__(
474
+ self,
475
+ cls_name,
476
+ K,
477
+ img_folder,
478
+ img_size=64,
479
+ ext='png',
480
+ seed=0,
481
+ only_cls_name: str = None,
482
+ only_cls_value: int = None,
483
+ all_neg: bool = False,
484
+ do_augment: bool = False,
485
+ do_transform: bool = True,
486
+ do_normalize: bool = True,
487
+ d2c: bool = False,
488
+ ) -> None:
489
+ self.cls_name = cls_name
490
+ self.K = K
491
+ self.img_folder = img_folder
492
+ self.ext = ext
493
+
494
+ if all_neg:
495
+ path = f'data/celeba_fewshots/K{K}_allneg_{cls_name}_{seed}.csv'
496
+ else:
497
+ path = f'data/celeba_fewshots/K{K}_{cls_name}_{seed}.csv'
498
+ self.df = pd.read_csv(path, index_col=0)
499
+ if only_cls_name is not None:
500
+ self.df = self.df[self.df[only_cls_name] == only_cls_value]
501
+
502
+ if d2c:
503
+ transform = [
504
+ d2c_crop(),
505
+ transforms.Resize(img_size),
506
+ ]
507
+ else:
508
+ transform = [
509
+ transforms.Resize(img_size),
510
+ transforms.CenterCrop(img_size),
511
+ ]
512
+ if do_augment:
513
+ transform.append(transforms.RandomHorizontalFlip())
514
+ if do_transform:
515
+ transform.append(transforms.ToTensor())
516
+ if do_normalize:
517
+ transform.append(
518
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
519
+ self.transform = transforms.Compose(transform)
520
+
521
+ def pos_count(self, cls_name):
522
+ return (self.df[cls_name] == 1).sum()
523
+
524
+ def neg_count(self, cls_name):
525
+ return (self.df[cls_name] == -1).sum()
526
+
527
+ def __len__(self):
528
+ return len(self.df)
529
+
530
+ def __getitem__(self, index):
531
+ row = self.df.iloc[index]
532
+ name = row.name.split('.')[0]
533
+ name = f'{name}.{self.ext}'
534
+
535
+ path = os.path.join(self.img_folder, name)
536
+ img = Image.open(path)
537
+
538
+ # (1, 1)
539
+ label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1)
540
+
541
+ if self.transform is not None:
542
+ img = self.transform(img)
543
+
544
+ return {'img': img, 'index': index, 'labels': label}
545
+
546
+
547
+ class CelebD2CAttrFewshotDataset(CelebAttrFewshotDataset):
548
+ def __init__(self,
549
+ cls_name,
550
+ K,
551
+ img_folder,
552
+ img_size=64,
553
+ ext='jpg',
554
+ seed=0,
555
+ only_cls_name: str = None,
556
+ only_cls_value: int = None,
557
+ all_neg: bool = False,
558
+ do_augment: bool = False,
559
+ do_transform: bool = True,
560
+ do_normalize: bool = True,
561
+ is_negative=False,
562
+ d2c: bool = True) -> None:
563
+ super().__init__(cls_name,
564
+ K,
565
+ img_folder,
566
+ img_size,
567
+ ext=ext,
568
+ seed=seed,
569
+ only_cls_name=only_cls_name,
570
+ only_cls_value=only_cls_value,
571
+ all_neg=all_neg,
572
+ do_augment=do_augment,
573
+ do_transform=do_transform,
574
+ do_normalize=do_normalize,
575
+ d2c=d2c)
576
+ self.is_negative = is_negative
577
+
578
+
579
+ class CelebHQAttrDataset(Dataset):
580
+ id_to_cls = [
581
+ '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
582
+ 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
583
+ 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
584
+ 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
585
+ 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
586
+ 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
587
+ 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
588
+ 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
589
+ 'Wearing_Necklace', 'Wearing_Necktie', 'Young'
590
+ ]
591
+ cls_to_id = {v: k for k, v in enumerate(id_to_cls)}
592
+
593
+ def __init__(self,
594
+ path=os.path.expanduser('datasets/celebahq256.lmdb'),
595
+ image_size=None,
596
+ attr_path=os.path.expanduser(
597
+ 'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
598
+ original_resolution=256,
599
+ do_augment: bool = False,
600
+ do_transform: bool = True,
601
+ do_normalize: bool = True):
602
+ super().__init__()
603
+ self.image_size = image_size
604
+ self.data = BaseLMDB(path, original_resolution, zfill=5)
605
+
606
+ transform = [
607
+ transforms.Resize(image_size),
608
+ transforms.CenterCrop(image_size),
609
+ ]
610
+ if do_augment:
611
+ transform.append(transforms.RandomHorizontalFlip())
612
+ if do_transform:
613
+ transform.append(transforms.ToTensor())
614
+ if do_normalize:
615
+ transform.append(
616
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
617
+ self.transform = transforms.Compose(transform)
618
+
619
+ with open(attr_path) as f:
620
+ # discard the top line
621
+ f.readline()
622
+ self.df = pd.read_csv(f, delim_whitespace=True)
623
+
624
+ def pos_count(self, cls_name):
625
+ return (self.df[cls_name] == 1).sum()
626
+
627
+ def neg_count(self, cls_name):
628
+ return (self.df[cls_name] == -1).sum()
629
+
630
+ def __len__(self):
631
+ return len(self.df)
632
+
633
+ def __getitem__(self, index):
634
+ row = self.df.iloc[index]
635
+ img_name = row.name
636
+ img_idx, ext = img_name.split('.')
637
+ img = self.data[img_idx]
638
+
639
+ labels = [0] * len(self.id_to_cls)
640
+ for k, v in row.items():
641
+ labels[self.cls_to_id[k]] = int(v)
642
+
643
+ if self.transform is not None:
644
+ img = self.transform(img)
645
+ return {'img': img, 'index': index, 'labels': torch.tensor(labels)}
646
+
647
+
648
+ class CelebHQAttrFewshotDataset(Dataset):
649
+ def __init__(self,
650
+ cls_name,
651
+ K,
652
+ path,
653
+ image_size,
654
+ original_resolution=256,
655
+ do_augment: bool = False,
656
+ do_transform: bool = True,
657
+ do_normalize: bool = True):
658
+ super().__init__()
659
+ self.image_size = image_size
660
+ self.cls_name = cls_name
661
+ self.K = K
662
+ self.data = BaseLMDB(path, original_resolution, zfill=5)
663
+
664
+ transform = [
665
+ transforms.Resize(image_size),
666
+ transforms.CenterCrop(image_size),
667
+ ]
668
+ if do_augment:
669
+ transform.append(transforms.RandomHorizontalFlip())
670
+ if do_transform:
671
+ transform.append(transforms.ToTensor())
672
+ if do_normalize:
673
+ transform.append(
674
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
675
+ self.transform = transforms.Compose(transform)
676
+
677
+ self.df = pd.read_csv(f'data/celebahq_fewshots/K{K}_{cls_name}.csv',
678
+ index_col=0)
679
+
680
+ def pos_count(self, cls_name):
681
+ return (self.df[cls_name] == 1).sum()
682
+
683
+ def neg_count(self, cls_name):
684
+ return (self.df[cls_name] == -1).sum()
685
+
686
+ def __len__(self):
687
+ return len(self.df)
688
+
689
+ def __getitem__(self, index):
690
+ row = self.df.iloc[index]
691
+ img_name = row.name
692
+ img_idx, ext = img_name.split('.')
693
+ img = self.data[img_idx]
694
+
695
+ # (1, 1)
696
+ label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1)
697
+
698
+ if self.transform is not None:
699
+ img = self.transform(img)
700
+
701
+ return {'img': img, 'index': index, 'labels': label}
702
+
703
+
704
+ class Repeat(Dataset):
705
+ def __init__(self, dataset, new_len) -> None:
706
+ super().__init__()
707
+ self.dataset = dataset
708
+ self.original_len = len(dataset)
709
+ self.new_len = new_len
710
+
711
+ def __len__(self):
712
+ return self.new_len
713
+
714
+ def __getitem__(self, index):
715
+ index = index % self.original_len
716
+ return self.dataset[index]
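A minimal usage sketch for the datasets defined above (an assumption-laden example, not part of the upload). Note that `FFHQlmdb` defaults to `datasets/ffhq256.lmdb` while `data_resize_ffhq.py` writes `datasets/ffhq.lmdb`, so pass the path that actually exists:

```python
from torch.utils.data import DataLoader

data = FFHQlmdb(path='datasets/ffhq256.lmdb', image_size=128, split='train')
loader = DataLoader(data, batch_size=16, shuffle=True, num_workers=4)

batch = next(iter(loader))
# images are normalized to [-1, 1]; 'index' is the record index within the lmdb
print(batch['img'].shape, batch['index'].shape)  # e.g. [16, 3, 128, 128], [16]
```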
dataset_util.py ADDED
@@ -0,0 +1,13 @@
+ import shutil
+ import os
+ from dist_utils import *
+
+
+ def use_cached_dataset_path(source_path, cache_path):
+     if get_rank() == 0:
+         if not os.path.exists(cache_path):
+             # shutil.rmtree(cache_path)
+             print(f'copying the data: {source_path} to {cache_path}')
+             shutil.copytree(source_path, cache_path)
+     barrier()
+     return cache_path
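A usage sketch (the cache path is hypothetical): every rank calls it with the same arguments, rank 0 performs the copy, and the other ranks wait at the barrier before reading from the cache:

```python
local_path = use_cached_dataset_path('datasets/ffhq256.lmdb',
                                     '/tmp/cache/ffhq256.lmdb')
# dataset = FFHQlmdb(path=local_path, image_size=128)
```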
datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt ADDED
The diff for this file is too large to render. See raw diff
 
datasets/celeba_anno/CelebAMask-HQ-pose-anno.txt ADDED
The diff for this file is too large to render. See raw diff
 
datasets/celeba_anno/list_attr_celeba.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0
+ size 26721026
diffusion/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from typing import Union
+
+ from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig
+
+ Sampler = Union[SpacedDiffusionBeatGans]
+ SamplerConfig = Union[SpacedDiffusionBeatGansConfig]
diffusion/base.py ADDED
@@ -0,0 +1,1109 @@
1
+ """
2
+ This code started out as a PyTorch port of Ho et al's diffusion models:
3
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
4
+
5
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
6
+ """
7
+
8
+ from model.unet_autoenc import AutoencReturn
9
+ from config_base import BaseConfig
10
+ import enum
11
+ import math
12
+
13
+ import numpy as np
14
+ import torch as th
15
+ from model import *
16
+ from model.nn import mean_flat
17
+ from typing import NamedTuple, Tuple
18
+ from choices import *
19
+ from torch.cuda.amp import autocast
20
+ import torch.nn.functional as F
21
+
22
+ from dataclasses import dataclass
23
+
24
+
25
+ @dataclass
26
+ class GaussianDiffusionBeatGansConfig(BaseConfig):
27
+ gen_type: GenerativeType
28
+ betas: Tuple[float]
29
+ model_type: ModelType
30
+ model_mean_type: ModelMeanType
31
+ model_var_type: ModelVarType
32
+ loss_type: LossType
33
+ rescale_timesteps: bool
34
+ fp16: bool
35
+ train_pred_xstart_detach: bool = True
36
+
37
+ def make_sampler(self):
38
+ return GaussianDiffusionBeatGans(self)
39
+
40
+
41
+ class GaussianDiffusionBeatGans:
42
+ """
43
+ Utilities for training and sampling diffusion models.
44
+
45
+ Ported directly from here, and then adapted over time to further experimentation.
46
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
47
+
48
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
49
+ starting at T and going to 1.
50
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
51
+ :param model_var_type: a ModelVarType determining how variance is output.
52
+ :param loss_type: a LossType determining the loss function to use.
53
+ :param rescale_timesteps: if True, pass floating point timesteps into the
54
+ model so that they are always scaled like in the
55
+ original paper (0 to 1000).
56
+ """
57
+ def __init__(self, conf: GaussianDiffusionBeatGansConfig):
58
+ self.conf = conf
59
+ self.model_mean_type = conf.model_mean_type
60
+ self.model_var_type = conf.model_var_type
61
+ self.loss_type = conf.loss_type
62
+ self.rescale_timesteps = conf.rescale_timesteps
63
+
64
+ # Use float64 for accuracy.
65
+ betas = np.array(conf.betas, dtype=np.float64)
66
+ self.betas = betas
67
+ assert len(betas.shape) == 1, "betas must be 1-D"
68
+ assert (betas > 0).all() and (betas <= 1).all()
69
+
70
+ self.num_timesteps = int(betas.shape[0])
71
+
72
+ alphas = 1.0 - betas
73
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
74
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
75
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
76
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps, )
77
+
78
+ # calculations for diffusion q(x_t | x_{t-1}) and others
79
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
80
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
81
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
82
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
83
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod -
84
+ 1)
85
+
86
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
87
+ self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) /
88
+ (1.0 - self.alphas_cumprod))
89
+ # log calculation clipped because the posterior variance is 0 at the
90
+ # beginning of the diffusion chain.
91
+ self.posterior_log_variance_clipped = np.log(
92
+ np.append(self.posterior_variance[1], self.posterior_variance[1:]))
93
+ self.posterior_mean_coef1 = (betas *
94
+ np.sqrt(self.alphas_cumprod_prev) /
95
+ (1.0 - self.alphas_cumprod))
96
+ self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) *
97
+ np.sqrt(alphas) /
98
+ (1.0 - self.alphas_cumprod))
99
+
100
+ def training_losses(self,
101
+ model: Model,
102
+ x_start: th.Tensor,
103
+ t: th.Tensor,
104
+ model_kwargs=None,
105
+ noise: th.Tensor = None):
106
+ """
107
+ Compute training losses for a single timestep.
108
+
109
+ :param model: the model to evaluate loss on.
110
+ :param x_start: the [N x C x ...] tensor of inputs.
111
+ :param t: a batch of timestep indices.
112
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
113
+ pass to the model. This can be used for conditioning.
114
+ :param noise: if specified, the specific Gaussian noise to try to remove.
115
+ :return: a dict with the key "loss" containing a tensor of shape [N].
116
+ Some mean or variance settings may also have other keys.
117
+ """
118
+ if model_kwargs is None:
119
+ model_kwargs = {}
120
+ if noise is None:
121
+ noise = th.randn_like(x_start)
122
+
123
+ x_t = self.q_sample(x_start, t, noise=noise)
124
+
125
+ terms = {'x_t': x_t}
126
+
127
+ if self.loss_type in [
128
+ LossType.mse,
129
+ LossType.l1,
130
+ ]:
131
+ with autocast(self.conf.fp16):
132
+ # x_t is static wrt. to the diffusion process
133
+ model_forward = model.forward(x=x_t.detach(),
134
+ t=self._scale_timesteps(t),
135
+ x_start=x_start.detach(),
136
+ **model_kwargs)
137
+ model_output = model_forward.pred
138
+
139
+ _model_output = model_output
140
+ if self.conf.train_pred_xstart_detach:
141
+ _model_output = _model_output.detach()
142
+ # get the pred xstart
143
+ p_mean_var = self.p_mean_variance(
144
+ model=DummyModel(pred=_model_output),
145
+ # gradient goes through x_t
146
+ x=x_t,
147
+ t=t,
148
+ clip_denoised=False)
149
+ terms['pred_xstart'] = p_mean_var['pred_xstart']
150
+
151
+ # model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
152
+
153
+ target_types = {
154
+ ModelMeanType.eps: noise,
155
+ }
156
+ target = target_types[self.model_mean_type]
157
+ assert model_output.shape == target.shape == x_start.shape
158
+
159
+ if self.loss_type == LossType.mse:
160
+ if self.model_mean_type == ModelMeanType.eps:
161
+ # (n, c, h, w) => (n, )
162
+ terms["mse"] = mean_flat((target - model_output)**2)
163
+ else:
164
+ raise NotImplementedError()
165
+ elif self.loss_type == LossType.l1:
166
+ # (n, c, h, w) => (n, )
167
+ terms["mse"] = mean_flat((target - model_output).abs())
168
+ else:
169
+ raise NotImplementedError()
170
+
171
+ if "vb" in terms:
172
+ # if learning the variance also use the vlb loss
173
+ terms["loss"] = terms["mse"] + terms["vb"]
174
+ else:
175
+ terms["loss"] = terms["mse"]
176
+ else:
177
+ raise NotImplementedError(self.loss_type)
178
+
179
+ return terms
180
+
181
+ def sample(self,
182
+ model: Model,
183
+ shape=None,
184
+ noise=None,
185
+ cond=None,
186
+ x_start=None,
187
+ clip_denoised=True,
188
+ model_kwargs=None,
189
+ progress=False):
190
+ """
191
+ Args:
192
+ x_start: given for the autoencoder
193
+ """
194
+ if model_kwargs is None:
195
+ model_kwargs = {}
196
+ if self.conf.model_type.has_autoenc():
197
+ model_kwargs['x_start'] = x_start
198
+ model_kwargs['cond'] = cond
199
+
200
+ if self.conf.gen_type == GenerativeType.ddpm:
201
+ return self.p_sample_loop(model,
202
+ shape=shape,
203
+ noise=noise,
204
+ clip_denoised=clip_denoised,
205
+ model_kwargs=model_kwargs,
206
+ progress=progress)
207
+ elif self.conf.gen_type == GenerativeType.ddim:
208
+ return self.ddim_sample_loop(model,
209
+ shape=shape,
210
+ noise=noise,
211
+ clip_denoised=clip_denoised,
212
+ model_kwargs=model_kwargs,
213
+ progress=progress)
214
+ else:
215
+ raise NotImplementedError()
216
+
217
+ def q_mean_variance(self, x_start, t):
218
+ """
219
+ Get the distribution q(x_t | x_0).
220
+
221
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
222
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
223
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
224
+ """
225
+ mean = (
226
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
227
+ x_start)
228
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t,
229
+ x_start.shape)
230
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod,
231
+ t, x_start.shape)
232
+ return mean, variance, log_variance
233
+
234
+ def q_sample(self, x_start, t, noise=None):
235
+ """
236
+ Diffuse the data for a given number of diffusion steps.
237
+
238
+ In other words, sample from q(x_t | x_0).
239
+
240
+ :param x_start: the initial data batch.
241
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
242
+ :param noise: if specified, the split-out normal noise.
243
+ :return: A noisy version of x_start.
244
+ """
245
+ if noise is None:
246
+ noise = th.randn_like(x_start)
247
+ assert noise.shape == x_start.shape
248
+ return (
249
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
250
+ x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
251
+ t, x_start.shape) * noise)
252
+
253
+ def q_posterior_mean_variance(self, x_start, x_t, t):
254
+ """
255
+ Compute the mean and variance of the diffusion posterior:
256
+
257
+ q(x_{t-1} | x_t, x_0)
258
+
259
+ """
260
+ assert x_start.shape == x_t.shape
261
+ posterior_mean = (
262
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) *
263
+ x_start +
264
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) *
265
+ x_t)
266
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t,
267
+ x_t.shape)
268
+ posterior_log_variance_clipped = _extract_into_tensor(
269
+ self.posterior_log_variance_clipped, t, x_t.shape)
270
+ assert (posterior_mean.shape[0] == posterior_variance.shape[0] ==
271
+ posterior_log_variance_clipped.shape[0] == x_start.shape[0])
272
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
273
+
274
+ def p_mean_variance(self,
275
+ model: Model,
276
+ x,
277
+ t,
278
+ clip_denoised=True,
279
+ denoised_fn=None,
280
+ model_kwargs=None):
281
+ """
282
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
283
+ the initial x, x_0.
284
+
285
+ :param model: the model, which takes a signal and a batch of timesteps
286
+ as input.
287
+ :param x: the [N x C x ...] tensor at time t.
288
+ :param t: a 1-D Tensor of timesteps.
289
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
290
+ :param denoised_fn: if not None, a function which applies to the
291
+ x_start prediction before it is used to sample. Applies before
292
+ clip_denoised.
293
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
294
+ pass to the model. This can be used for conditioning.
295
+ :return: a dict with the following keys:
296
+ - 'mean': the model mean output.
297
+ - 'variance': the model variance output.
298
+ - 'log_variance': the log of 'variance'.
299
+ - 'pred_xstart': the prediction for x_0.
300
+ """
301
+ if model_kwargs is None:
302
+ model_kwargs = {}
303
+
304
+ B, C = x.shape[:2]
305
+ assert t.shape == (B, )
306
+ with autocast(self.conf.fp16):
307
+ model_forward = model.forward(x=x,
308
+ t=self._scale_timesteps(t),
309
+ **model_kwargs)
310
+ model_output = model_forward.pred
311
+
312
+ if self.model_var_type in [
313
+ ModelVarType.fixed_large, ModelVarType.fixed_small
314
+ ]:
315
+ model_variance, model_log_variance = {
316
+ # for fixedlarge, we set the initial (log-)variance like so
317
+ # to get a better decoder log likelihood.
318
+ ModelVarType.fixed_large: (
319
+ np.append(self.posterior_variance[1], self.betas[1:]),
320
+ np.log(
321
+ np.append(self.posterior_variance[1], self.betas[1:])),
322
+ ),
323
+ ModelVarType.fixed_small: (
324
+ self.posterior_variance,
325
+ self.posterior_log_variance_clipped,
326
+ ),
327
+ }[self.model_var_type]
328
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
329
+ model_log_variance = _extract_into_tensor(model_log_variance, t,
330
+ x.shape)
331
+
332
+ def process_xstart(x):
333
+ if denoised_fn is not None:
334
+ x = denoised_fn(x)
335
+ if clip_denoised:
336
+ return x.clamp(-1, 1)
337
+ return x
338
+
339
+ if self.model_mean_type in [
340
+ ModelMeanType.eps,
341
+ ]:
342
+ if self.model_mean_type == ModelMeanType.eps:
343
+ pred_xstart = process_xstart(
344
+ self._predict_xstart_from_eps(x_t=x, t=t,
345
+ eps=model_output))
346
+ else:
347
+ raise NotImplementedError()
348
+ model_mean, _, _ = self.q_posterior_mean_variance(
349
+ x_start=pred_xstart, x_t=x, t=t)
350
+ else:
351
+ raise NotImplementedError(self.model_mean_type)
352
+
353
+ assert (model_mean.shape == model_log_variance.shape ==
354
+ pred_xstart.shape == x.shape)
355
+ return {
356
+ "mean": model_mean,
357
+ "variance": model_variance,
358
+ "log_variance": model_log_variance,
359
+ "pred_xstart": pred_xstart,
360
+ 'model_forward': model_forward,
361
+ }
362
+
363
+ def _predict_xstart_from_eps(self, x_t, t, eps):
364
+ assert x_t.shape == eps.shape
365
+ return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t,
366
+ x_t.shape) * x_t -
367
+ _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t,
368
+ x_t.shape) * eps)
369
+
370
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
371
+ assert x_t.shape == xprev.shape
372
+ return ( # (xprev - coef2*x_t) / coef1
373
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape)
374
+ * xprev - _extract_into_tensor(
375
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
376
+ x_t.shape) * x_t)
377
+
378
+ def _predict_xstart_from_scaled_xstart(self, t, scaled_xstart):
379
+ return scaled_xstart * _extract_into_tensor(
380
+ self.sqrt_recip_alphas_cumprod, t, scaled_xstart.shape)
381
+
382
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
383
+ return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t,
384
+ x_t.shape) * x_t -
385
+ pred_xstart) / _extract_into_tensor(
386
+ self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
387
+
388
+ def _predict_eps_from_scaled_xstart(self, x_t, t, scaled_xstart):
389
+ """
390
+ Args:
391
+ scaled_xstart: is supposed to be sqrt(alphacum) * x_0
392
+ """
393
+ # 1 / sqrt(1-alphabar) * (x_t - scaled xstart)
394
+ return (x_t - scaled_xstart) / _extract_into_tensor(
395
+ self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
396
+
397
+ def _scale_timesteps(self, t):
398
+ if self.rescale_timesteps:
399
+ # scale t to be maxed out at 1000 steps
400
+ return t.float() * (1000.0 / self.num_timesteps)
401
+ return t
402
+
403
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
404
+ """
405
+ Compute the mean for the previous step, given a function cond_fn that
406
+ computes the gradient of a conditional log probability with respect to
407
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
408
+ condition on y.
409
+
410
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
411
+ """
412
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
413
+ new_mean = (p_mean_var["mean"].float() +
414
+ p_mean_var["variance"] * gradient.float())
415
+ return new_mean
416
+
417
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
418
+ """
419
+ Compute what the p_mean_variance output would have been, should the
420
+ model's score function be conditioned by cond_fn.
421
+
422
+ See condition_mean() for details on cond_fn.
423
+
424
+ Unlike condition_mean(), this instead uses the conditioning strategy
425
+ from Song et al (2020).
426
+ """
427
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
428
+
429
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
430
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
431
+ x, self._scale_timesteps(t), **model_kwargs)
432
+
433
+ out = p_mean_var.copy()
434
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
435
+ out["mean"], _, _ = self.q_posterior_mean_variance(
436
+ x_start=out["pred_xstart"], x_t=x, t=t)
437
+ return out
438
+
439
+ def p_sample(
440
+ self,
441
+ model: Model,
442
+ x,
443
+ t,
444
+ clip_denoised=True,
445
+ denoised_fn=None,
446
+ cond_fn=None,
447
+ model_kwargs=None,
448
+ ):
449
+ """
450
+ Sample x_{t-1} from the model at the given timestep.
451
+
452
+ :param model: the model to sample from.
453
+ :param x: the current tensor at x_{t-1}.
454
+ :param t: the value of t, starting at 0 for the first diffusion step.
455
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
456
+ :param denoised_fn: if not None, a function which applies to the
457
+ x_start prediction before it is used to sample.
458
+ :param cond_fn: if not None, this is a gradient function that acts
459
+ similarly to the model.
460
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
461
+ pass to the model. This can be used for conditioning.
462
+ :return: a dict containing the following keys:
463
+ - 'sample': a random sample from the model.
464
+ - 'pred_xstart': a prediction of x_0.
465
+ """
466
+ out = self.p_mean_variance(
467
+ model,
468
+ x,
469
+ t,
470
+ clip_denoised=clip_denoised,
471
+ denoised_fn=denoised_fn,
472
+ model_kwargs=model_kwargs,
473
+ )
474
+ noise = th.randn_like(x)
475
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
476
+ ) # no noise when t == 0
477
+ if cond_fn is not None:
478
+ out["mean"] = self.condition_mean(cond_fn,
479
+ out,
480
+ x,
481
+ t,
482
+ model_kwargs=model_kwargs)
483
+ sample = out["mean"] + nonzero_mask * th.exp(
484
+ 0.5 * out["log_variance"]) * noise
485
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
486
+
487
+ def p_sample_loop(
488
+ self,
489
+ model: Model,
490
+ shape=None,
491
+ noise=None,
492
+ clip_denoised=True,
493
+ denoised_fn=None,
494
+ cond_fn=None,
495
+ model_kwargs=None,
496
+ device=None,
497
+ progress=False,
498
+ ):
499
+ """
500
+ Generate samples from the model.
501
+
502
+ :param model: the model module.
503
+ :param shape: the shape of the samples, (N, C, H, W).
504
+ :param noise: if specified, the noise from the encoder to sample.
505
+ Should be of the same shape as `shape`.
506
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
507
+ :param denoised_fn: if not None, a function which applies to the
508
+ x_start prediction before it is used to sample.
509
+ :param cond_fn: if not None, this is a gradient function that acts
510
+ similarly to the model.
511
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
512
+ pass to the model. This can be used for conditioning.
513
+ :param device: if specified, the device to create the samples on.
514
+ If not specified, use a model parameter's device.
515
+ :param progress: if True, show a tqdm progress bar.
516
+ :return: a non-differentiable batch of samples.
517
+ """
518
+ final = None
519
+ for sample in self.p_sample_loop_progressive(
520
+ model,
521
+ shape,
522
+ noise=noise,
523
+ clip_denoised=clip_denoised,
524
+ denoised_fn=denoised_fn,
525
+ cond_fn=cond_fn,
526
+ model_kwargs=model_kwargs,
527
+ device=device,
528
+ progress=progress,
529
+ ):
530
+ final = sample
531
+ return final["sample"]
532
+
533
+ def p_sample_loop_progressive(
534
+ self,
535
+ model: Model,
536
+ shape=None,
537
+ noise=None,
538
+ clip_denoised=True,
539
+ denoised_fn=None,
540
+ cond_fn=None,
541
+ model_kwargs=None,
542
+ device=None,
543
+ progress=False,
544
+ ):
545
+ """
546
+ Generate samples from the model and yield intermediate samples from
547
+ each timestep of diffusion.
548
+
549
+ Arguments are the same as p_sample_loop().
550
+ Returns a generator over dicts, where each dict is the return value of
551
+ p_sample().
552
+ """
553
+ if device is None:
554
+ device = next(model.parameters()).device
555
+ if noise is not None:
556
+ img = noise
557
+ else:
558
+ assert isinstance(shape, (tuple, list))
559
+ img = th.randn(*shape, device=device)
560
+ indices = list(range(self.num_timesteps))[::-1]
561
+
562
+ if progress:
563
+ # Lazy import so that we don't depend on tqdm.
564
+ from tqdm.auto import tqdm
565
+
566
+ indices = tqdm(indices)
567
+
568
+ for i in indices:
569
+ # t = th.tensor([i] * shape[0], device=device)
570
+ t = th.tensor([i] * len(img), device=device)
571
+ with th.no_grad():
572
+ out = self.p_sample(
573
+ model,
574
+ img,
575
+ t,
576
+ clip_denoised=clip_denoised,
577
+ denoised_fn=denoised_fn,
578
+ cond_fn=cond_fn,
579
+ model_kwargs=model_kwargs,
580
+ )
581
+ yield out
582
+ img = out["sample"]
583
+
584
+ def ddim_sample(
585
+ self,
586
+ model: Model,
587
+ x,
588
+ t,
589
+ clip_denoised=True,
590
+ denoised_fn=None,
591
+ cond_fn=None,
592
+ model_kwargs=None,
593
+ eta=0.0,
594
+ ):
595
+ """
596
+ Sample x_{t-1} from the model using DDIM.
597
+
598
+ Same usage as p_sample().
599
+ """
600
+ out = self.p_mean_variance(
601
+ model,
602
+ x,
603
+ t,
604
+ clip_denoised=clip_denoised,
605
+ denoised_fn=denoised_fn,
606
+ model_kwargs=model_kwargs,
607
+ )
608
+ if cond_fn is not None:
609
+ out = self.condition_score(cond_fn,
610
+ out,
611
+ x,
612
+ t,
613
+ model_kwargs=model_kwargs)
614
+
615
+ # Usually our model outputs epsilon, but we re-derive it
616
+ # in case we used x_start or x_prev prediction.
617
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
618
+
619
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
620
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t,
621
+ x.shape)
622
+ sigma = (eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) *
623
+ th.sqrt(1 - alpha_bar / alpha_bar_prev))
624
+ # Equation 12.
625
+ noise = th.randn_like(x)
626
+ mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_prev) +
627
+ th.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
628
+ nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
629
+ ) # no noise when t == 0
630
+ sample = mean_pred + nonzero_mask * sigma * noise
631
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
632
+
633
+ def ddim_reverse_sample(
634
+ self,
635
+ model: Model,
636
+ x,
637
+ t,
638
+ clip_denoised=True,
639
+ denoised_fn=None,
640
+ model_kwargs=None,
641
+ eta=0.0,
642
+ ):
643
+ """
644
+ Sample x_{t+1} from the model using DDIM reverse ODE.
645
+ NOTE: never used ?
646
+ """
647
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
648
+ out = self.p_mean_variance(
649
+ model,
650
+ x,
651
+ t,
652
+ clip_denoised=clip_denoised,
653
+ denoised_fn=denoised_fn,
654
+ model_kwargs=model_kwargs,
655
+ )
656
+ # Usually our model outputs epsilon, but we re-derive it
657
+ # in case we used x_start or x_prev prediction.
658
+ eps = (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape)
659
+ * x - out["pred_xstart"]) / _extract_into_tensor(
660
+ self.sqrt_recipm1_alphas_cumprod, t, x.shape)
661
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t,
662
+ x.shape)
663
+
664
+ # Equation 12. reversed (DDIM paper) (th.sqrt == torch.sqrt)
665
+ mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_next) +
666
+ th.sqrt(1 - alpha_bar_next) * eps)
667
+
668
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
669
+
670
+ def ddim_reverse_sample_loop(
671
+ self,
672
+ model: Model,
673
+ x,
674
+ clip_denoised=True,
675
+ denoised_fn=None,
676
+ model_kwargs=None,
677
+ eta=0.0,
678
+ device=None,
679
+ ):
680
+ if device is None:
681
+ device = next(model.parameters()).device
682
+ sample_t = []
683
+ xstart_t = []
684
+ T = []
685
+ indices = list(range(self.num_timesteps))
686
+ sample = x
687
+ for i in indices:
688
+ t = th.tensor([i] * len(sample), device=device)
689
+ with th.no_grad():
690
+ out = self.ddim_reverse_sample(model,
691
+ sample,
692
+ t=t,
693
+ clip_denoised=clip_denoised,
694
+ denoised_fn=denoised_fn,
695
+ model_kwargs=model_kwargs,
696
+ eta=eta)
697
+ sample = out['sample']
698
+ # [1, ..., T]
699
+ sample_t.append(sample)
700
+ # [0, ...., T-1]
701
+ xstart_t.append(out['pred_xstart'])
702
+ # [0, ..., T-1] ready to use
703
+ T.append(t)
704
+
705
+ return {
706
+ # x_T (the final inverted noise)
707
+ 'sample': sample,
708
+ # (1, ..., T)
709
+ 'sample_t': sample_t,
710
+ # xstart here is a bit different from sampling from T = T-1 to T = 0
711
+ # may not be exact
712
+ 'xstart_t': xstart_t,
713
+ 'T': T,
714
+ }
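The dictionary returned above is what `encode_stochastic` in experiment.py consumes. A hedged usage sketch (here `sampler`, `model`, `img`, and `cond` stand for a spaced diffusion sampler, a trained autoencoder, a batch of images in [-1, 1], and its semantic code; they are placeholders, not objects defined in this file):

```python
# Run the DDIM reverse ODE to invert real images into x_T (the stochastic subcode).
out = sampler.ddim_reverse_sample_loop(model, img, model_kwargs={'cond': cond})
x_T = out['sample']           # fully inverted noise, usable for near-exact reconstruction
trajectory = out['sample_t']  # intermediate x_t for t = 1..T, if needed for inspection
```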
715
+
716
+ def ddim_sample_loop(
717
+ self,
718
+ model: Model,
719
+ shape=None,
720
+ noise=None,
721
+ clip_denoised=True,
722
+ denoised_fn=None,
723
+ cond_fn=None,
724
+ model_kwargs=None,
725
+ device=None,
726
+ progress=False,
727
+ eta=0.0,
728
+ ):
729
+ """
730
+ Generate samples from the model using DDIM.
731
+
732
+ Same usage as p_sample_loop().
733
+ """
734
+ final = None
735
+ for sample in self.ddim_sample_loop_progressive(
736
+ model,
737
+ shape,
738
+ noise=noise,
739
+ clip_denoised=clip_denoised,
740
+ denoised_fn=denoised_fn,
741
+ cond_fn=cond_fn,
742
+ model_kwargs=model_kwargs,
743
+ device=device,
744
+ progress=progress,
745
+ eta=eta,
746
+ ):
747
+ final = sample
748
+ return final["sample"]
749
+
750
+ def ddim_sample_loop_progressive(
751
+ self,
752
+ model: Model,
753
+ shape=None,
754
+ noise=None,
755
+ clip_denoised=True,
756
+ denoised_fn=None,
757
+ cond_fn=None,
758
+ model_kwargs=None,
759
+ device=None,
760
+ progress=False,
761
+ eta=0.0,
762
+ ):
763
+ """
764
+ Use DDIM to sample from the model and yield intermediate samples from
765
+ each timestep of DDIM.
766
+
767
+ Same usage as p_sample_loop_progressive().
768
+ """
769
+ if device is None:
770
+ device = next(model.parameters()).device
771
+ if noise is not None:
772
+ img = noise
773
+ else:
774
+ assert isinstance(shape, (tuple, list))
775
+ img = th.randn(*shape, device=device)
776
+ indices = list(range(self.num_timesteps))[::-1]
777
+
778
+ if progress:
779
+ # Lazy import so that we don't depend on tqdm.
780
+ from tqdm.auto import tqdm
781
+
782
+ indices = tqdm(indices)
783
+
784
+ for i in indices:
785
+
786
+ if isinstance(model_kwargs, list):
787
+ # index dependent model kwargs
788
+ # (T-1, ..., 0)
789
+ _kwargs = model_kwargs[i]
790
+ else:
791
+ _kwargs = model_kwargs
792
+
793
+ t = th.tensor([i] * len(img), device=device)
794
+ with th.no_grad():
795
+ out = self.ddim_sample(
796
+ model,
797
+ img,
798
+ t,
799
+ clip_denoised=clip_denoised,
800
+ denoised_fn=denoised_fn,
801
+ cond_fn=cond_fn,
802
+ model_kwargs=_kwargs,
803
+ eta=eta,
804
+ )
805
+ out['t'] = t
806
+ yield out
807
+ img = out["sample"]
808
+
809
+ def _vb_terms_bpd(self,
810
+ model: Model,
811
+ x_start,
812
+ x_t,
813
+ t,
814
+ clip_denoised=True,
815
+ model_kwargs=None):
816
+ """
817
+ Get a term for the variational lower-bound.
818
+
819
+ The resulting units are bits (rather than nats, as one might expect).
820
+ This allows for comparison to other papers.
821
+
822
+ :return: a dict with the following keys:
823
+ - 'output': a shape [N] tensor of NLLs or KLs.
824
+ - 'pred_xstart': the x_0 predictions.
825
+ """
826
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
827
+ x_start=x_start, x_t=x_t, t=t)
828
+ out = self.p_mean_variance(model,
829
+ x_t,
830
+ t,
831
+ clip_denoised=clip_denoised,
832
+ model_kwargs=model_kwargs)
833
+ kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"],
834
+ out["log_variance"])
835
+ kl = mean_flat(kl) / np.log(2.0)
836
+
837
+ decoder_nll = -discretized_gaussian_log_likelihood(
838
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"])
839
+ assert decoder_nll.shape == x_start.shape
840
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
841
+
842
+ # At the first timestep return the decoder NLL,
843
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
844
+ output = th.where((t == 0), decoder_nll, kl)
845
+ return {
846
+ "output": output,
847
+ "pred_xstart": out["pred_xstart"],
848
+ 'model_forward': out['model_forward'],
849
+ }
850
+
851
+ def _prior_bpd(self, x_start):
852
+ """
853
+ Get the prior KL term for the variational lower-bound, measured in
854
+ bits-per-dim.
855
+
856
+ This term can't be optimized, as it only depends on the encoder.
857
+
858
+ :param x_start: the [N x C x ...] tensor of inputs.
859
+ :return: a batch of [N] KL values (in bits), one per batch element.
860
+ """
861
+ batch_size = x_start.shape[0]
862
+ t = th.tensor([self.num_timesteps - 1] * batch_size,
863
+ device=x_start.device)
864
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
865
+ kl_prior = normal_kl(mean1=qt_mean,
866
+ logvar1=qt_log_variance,
867
+ mean2=0.0,
868
+ logvar2=0.0)
869
+ return mean_flat(kl_prior) / np.log(2.0)
870
+
871
+ def calc_bpd_loop(self,
872
+ model: Model,
873
+ x_start,
874
+ clip_denoised=True,
875
+ model_kwargs=None):
876
+ """
877
+ Compute the entire variational lower-bound, measured in bits-per-dim,
878
+ as well as other related quantities.
879
+
880
+ :param model: the model to evaluate loss on.
881
+ :param x_start: the [N x C x ...] tensor of inputs.
882
+ :param clip_denoised: if True, clip denoised samples.
883
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
884
+ pass to the model. This can be used for conditioning.
885
+
886
+ :return: a dict containing the following keys:
887
+ - total_bpd: the total variational lower-bound, per batch element.
888
+ - prior_bpd: the prior term in the lower-bound.
889
+ - vb: an [N x T] tensor of terms in the lower-bound.
890
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
891
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
892
+ """
893
+ device = x_start.device
894
+ batch_size = x_start.shape[0]
895
+
896
+ vb = []
897
+ xstart_mse = []
898
+ mse = []
899
+ for t in list(range(self.num_timesteps))[::-1]:
900
+ t_batch = th.tensor([t] * batch_size, device=device)
901
+ noise = th.randn_like(x_start)
902
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
903
+ # Calculate VLB term at the current timestep
904
+ with th.no_grad():
905
+ out = self._vb_terms_bpd(
906
+ model,
907
+ x_start=x_start,
908
+ x_t=x_t,
909
+ t=t_batch,
910
+ clip_denoised=clip_denoised,
911
+ model_kwargs=model_kwargs,
912
+ )
913
+ vb.append(out["output"])
914
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start)**2))
915
+ eps = self._predict_eps_from_xstart(x_t, t_batch,
916
+ out["pred_xstart"])
917
+ mse.append(mean_flat((eps - noise)**2))
918
+
919
+ vb = th.stack(vb, dim=1)
920
+ xstart_mse = th.stack(xstart_mse, dim=1)
921
+ mse = th.stack(mse, dim=1)
922
+
923
+ prior_bpd = self._prior_bpd(x_start)
924
+ total_bpd = vb.sum(dim=1) + prior_bpd
925
+ return {
926
+ "total_bpd": total_bpd,
927
+ "prior_bpd": prior_bpd,
928
+ "vb": vb,
929
+ "xstart_mse": xstart_mse,
930
+ "mse": mse,
931
+ }
932
+
933
+
934
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
935
+ """
936
+ Extract values from a 1-D numpy array for a batch of indices.
937
+
938
+ :param arr: the 1-D numpy array.
939
+ :param timesteps: a tensor of indices into the array to extract.
940
+ :param broadcast_shape: a larger shape of K dimensions with the batch
941
+ dimension equal to the length of timesteps.
942
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
943
+ """
944
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
945
+ while len(res.shape) < len(broadcast_shape):
946
+ res = res[..., None]
947
+ return res.expand(broadcast_shape)
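For clarity, `_extract_into_tensor` is just a gather plus broadcasting. A standalone sketch of the equivalent operation (NumPy/PyTorch only; the schedule array is made up for illustration):

```python
import numpy as np
import torch as th

alphas_cumprod = np.linspace(0.999, 0.001, 1000)    # any 1-D schedule array
t = th.tensor([0, 499, 999])                         # a batch of timesteps
vals = th.from_numpy(alphas_cumprod)[t].float()      # shape (3,)
vals = vals.view(-1, 1, 1, 1).expand(3, 3, 64, 64)   # broadcastable against an (N, C, H, W) batch
```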
948
+
949
+
950
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
951
+ """
952
+ Get a pre-defined beta schedule for the given name.
953
+
954
+ The beta schedule library consists of beta schedules which remain similar
955
+ in the limit of num_diffusion_timesteps.
956
+ Beta schedules may be added, but should not be removed or changed once
957
+ they are committed to maintain backwards compatibility.
958
+ """
959
+ if schedule_name == "linear":
960
+ # Linear schedule from Ho et al, extended to work for any number of
961
+ # diffusion steps.
962
+ scale = 1000 / num_diffusion_timesteps
963
+ beta_start = scale * 0.0001
964
+ beta_end = scale * 0.02
965
+ return np.linspace(beta_start,
966
+ beta_end,
967
+ num_diffusion_timesteps,
968
+ dtype=np.float64)
969
+ elif schedule_name == "cosine":
970
+ return betas_for_alpha_bar(
971
+ num_diffusion_timesteps,
972
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2,
973
+ )
974
+ elif schedule_name == "const0.01":
975
+ scale = 1000 / num_diffusion_timesteps
976
+ return np.array([scale * 0.01] * num_diffusion_timesteps,
977
+ dtype=np.float64)
978
+ elif schedule_name == "const0.015":
979
+ scale = 1000 / num_diffusion_timesteps
980
+ return np.array([scale * 0.015] * num_diffusion_timesteps,
981
+ dtype=np.float64)
982
+ elif schedule_name == "const0.008":
983
+ scale = 1000 / num_diffusion_timesteps
984
+ return np.array([scale * 0.008] * num_diffusion_timesteps,
985
+ dtype=np.float64)
986
+ elif schedule_name == "const0.0065":
987
+ scale = 1000 / num_diffusion_timesteps
988
+ return np.array([scale * 0.0065] * num_diffusion_timesteps,
989
+ dtype=np.float64)
990
+ elif schedule_name == "const0.0055":
991
+ scale = 1000 / num_diffusion_timesteps
992
+ return np.array([scale * 0.0055] * num_diffusion_timesteps,
993
+ dtype=np.float64)
994
+ elif schedule_name == "const0.0045":
995
+ scale = 1000 / num_diffusion_timesteps
996
+ return np.array([scale * 0.0045] * num_diffusion_timesteps,
997
+ dtype=np.float64)
998
+ elif schedule_name == "const0.0035":
999
+ scale = 1000 / num_diffusion_timesteps
1000
+ return np.array([scale * 0.0035] * num_diffusion_timesteps,
1001
+ dtype=np.float64)
1002
+ elif schedule_name == "const0.0025":
1003
+ scale = 1000 / num_diffusion_timesteps
1004
+ return np.array([scale * 0.0025] * num_diffusion_timesteps,
1005
+ dtype=np.float64)
1006
+ elif schedule_name == "const0.0015":
1007
+ scale = 1000 / num_diffusion_timesteps
1008
+ return np.array([scale * 0.0015] * num_diffusion_timesteps,
1009
+ dtype=np.float64)
1010
+ else:
1011
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
1012
+
1013
+
1014
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
1015
+ """
1016
+ Create a beta schedule that discretizes the given alpha_t_bar function,
1017
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
1018
+
1019
+ :param num_diffusion_timesteps: the number of betas to produce.
1020
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
1021
+ produces the cumulative product of (1-beta) up to that
1022
+ part of the diffusion process.
1023
+ :param max_beta: the maximum beta to use; use values lower than 1 to
1024
+ prevent singularities.
1025
+ """
1026
+ betas = []
1027
+ for i in range(num_diffusion_timesteps):
1028
+ t1 = i / num_diffusion_timesteps
1029
+ t2 = (i + 1) / num_diffusion_timesteps
1030
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
1031
+ return np.array(betas)
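As a quick check of how `betas_for_alpha_bar` discretizes a continuous alpha-bar curve, this standalone snippet mirrors what `get_named_beta_schedule("cosine", T)` computes:

```python
import math
import numpy as np

T = 1000
alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
betas = np.array([
    min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999) for i in range(T)
])
alphas_cumprod = np.cumprod(1 - betas)
# betas start tiny and grow; alphas_cumprod decays smoothly towards ~0 at t = T.
print(betas[:3], alphas_cumprod[-1])
```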
1032
+
1033
+
1034
+ def normal_kl(mean1, logvar1, mean2, logvar2):
1035
+ """
1036
+ Compute the KL divergence between two gaussians.
1037
+
1038
+ Shapes are automatically broadcasted, so batches can be compared to
1039
+ scalars, among other use cases.
1040
+ """
1041
+ tensor = None
1042
+ for obj in (mean1, logvar1, mean2, logvar2):
1043
+ if isinstance(obj, th.Tensor):
1044
+ tensor = obj
1045
+ break
1046
+ assert tensor is not None, "at least one argument must be a Tensor"
1047
+
1048
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
1049
+ # Tensors, but it does not work for th.exp().
1050
+ logvar1, logvar2 = [
1051
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
1052
+ for x in (logvar1, logvar2)
1053
+ ]
1054
+
1055
+ return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) +
1056
+ ((mean1 - mean2)**2) * th.exp(-logvar2))
1057
+
1058
+
1059
+ def approx_standard_normal_cdf(x):
1060
+ """
1061
+ A fast approximation of the cumulative distribution function of the
1062
+ standard normal.
1063
+ """
1064
+ return 0.5 * (
1065
+ 1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
1066
+
1067
+
1068
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
1069
+ """
1070
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
1071
+ given image.
1072
+
1073
+ :param x: the target images. It is assumed that this was uint8 values,
1074
+ rescaled to the range [-1, 1].
1075
+ :param means: the Gaussian mean Tensor.
1076
+ :param log_scales: the Gaussian log stddev Tensor.
1077
+ :return: a tensor like x of log probabilities (in nats).
1078
+ """
1079
+ assert x.shape == means.shape == log_scales.shape
1080
+ centered_x = x - means
1081
+ inv_stdv = th.exp(-log_scales)
1082
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
1083
+ cdf_plus = approx_standard_normal_cdf(plus_in)
1084
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
1085
+ cdf_min = approx_standard_normal_cdf(min_in)
1086
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
1087
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
1088
+ cdf_delta = cdf_plus - cdf_min
1089
+ log_probs = th.where(
1090
+ x < -0.999,
1091
+ log_cdf_plus,
1092
+ th.where(x > 0.999, log_one_minus_cdf_min,
1093
+ th.log(cdf_delta.clamp(min=1e-12))),
1094
+ )
1095
+ assert log_probs.shape == x.shape
1096
+ return log_probs
1097
+
1098
+
1099
+ class DummyModel(th.nn.Module):
1100
+ def __init__(self, pred):
1101
+ super().__init__()
1102
+ self.pred = pred
1103
+
1104
+ def forward(self, *args, **kwargs):
1105
+ return DummyReturn(pred=self.pred)
1106
+
1107
+
1108
+ class DummyReturn(NamedTuple):
1109
+ pred: th.Tensor
diffusion/diffusion.py ADDED
@@ -0,0 +1,160 @@
1
+ from .base import *
2
+ from dataclasses import dataclass
3
+
4
+
5
+ def space_timesteps(num_timesteps, section_counts):
6
+ """
7
+ Create a list of timesteps to use from an original diffusion process,
8
+ given the number of timesteps we want to take from equally-sized portions
9
+ of the original process.
10
+
11
+ For example, if there are 300 timesteps and the section counts are [10,15,20]
12
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
13
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
14
+
15
+ If the stride is a string starting with "ddim", then the fixed striding
16
+ from the DDIM paper is used, and only one section is allowed.
17
+
18
+ :param num_timesteps: the number of diffusion steps in the original
19
+ process to divide up.
20
+ :param section_counts: either a list of numbers, or a string containing
21
+ comma-separated numbers, indicating the step count
22
+ per section. As a special case, use "ddimN" where N
23
+ is a number of steps to use the striding from the
24
+ DDIM paper.
25
+ :return: a set of diffusion steps from the original process to use.
26
+ """
27
+ if isinstance(section_counts, str):
28
+ if section_counts.startswith("ddim"):
29
+ desired_count = int(section_counts[len("ddim"):])
30
+ for i in range(1, num_timesteps):
31
+ if len(range(0, num_timesteps, i)) == desired_count:
32
+ return set(range(0, num_timesteps, i))
33
+ raise ValueError(
34
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
35
+ )
36
+ section_counts = [int(x) for x in section_counts.split(",")]
37
+ size_per = num_timesteps // len(section_counts)
38
+ extra = num_timesteps % len(section_counts)
39
+ start_idx = 0
40
+ all_steps = []
41
+ for i, section_count in enumerate(section_counts):
42
+ size = size_per + (1 if i < extra else 0)
43
+ if size < section_count:
44
+ raise ValueError(
45
+ f"cannot divide section of {size} steps into {section_count}")
46
+ if section_count <= 1:
47
+ frac_stride = 1
48
+ else:
49
+ frac_stride = (size - 1) / (section_count - 1)
50
+ cur_idx = 0.0
51
+ taken_steps = []
52
+ for _ in range(section_count):
53
+ taken_steps.append(start_idx + round(cur_idx))
54
+ cur_idx += frac_stride
55
+ all_steps += taken_steps
56
+ start_idx += size
57
+ return set(all_steps)
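A worked example of the spacing rules described in the docstring above (the `diffusion.diffusion` import path assumes the repository root is on `PYTHONPATH`; adjust as needed):

```python
from diffusion.diffusion import space_timesteps

# "ddimN": a single section with the fixed integer stride from the DDIM paper.
print(sorted(space_timesteps(1000, "ddim50")))   # [0, 20, 40, ..., 980]

# List form: 300 original steps split into three 100-step sections,
# keeping 10, 15 and 20 steps respectively (45 retained steps in total).
print(len(space_timesteps(300, [10, 15, 20])))   # 45
```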
58
+
59
+
60
+ @dataclass
61
+ class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
62
+ use_timesteps: Tuple[int] = None
63
+
64
+ def make_sampler(self):
65
+ return SpacedDiffusionBeatGans(self)
66
+
67
+
68
+ class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
69
+ """
70
+ A diffusion process which can skip steps in a base diffusion process.
71
+
72
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
73
+ original diffusion process to retain.
74
+ :param kwargs: the kwargs to create the base diffusion process.
75
+ """
76
+ def __init__(self, conf: SpacedDiffusionBeatGansConfig):
77
+ self.conf = conf
78
+ self.use_timesteps = set(conf.use_timesteps)
79
+ # how the new t's mapped to the old t's
80
+ self.timestep_map = []
81
+ self.original_num_steps = len(conf.betas)
82
+
83
+ base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa
84
+ last_alpha_cumprod = 1.0
85
+ new_betas = []
86
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
87
+ if i in self.use_timesteps:
88
+ # getting the new betas of the new timesteps
89
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
90
+ last_alpha_cumprod = alpha_cumprod
91
+ self.timestep_map.append(i)
92
+ conf.betas = np.array(new_betas)
93
+ super().__init__(conf)
94
+
95
+ def p_mean_variance(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
96
+ return super().p_mean_variance(self._wrap_model(model), *args,
97
+ **kwargs)
98
+
99
+ def training_losses(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
100
+ return super().training_losses(self._wrap_model(model), *args,
101
+ **kwargs)
102
+
103
+ def condition_mean(self, cond_fn, *args, **kwargs):
104
+ return super().condition_mean(self._wrap_model(cond_fn), *args,
105
+ **kwargs)
106
+
107
+ def condition_score(self, cond_fn, *args, **kwargs):
108
+ return super().condition_score(self._wrap_model(cond_fn), *args,
109
+ **kwargs)
110
+
111
+ def _wrap_model(self, model: Model):
112
+ if isinstance(model, _WrappedModel):
113
+ return model
114
+ return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
115
+ self.original_num_steps)
116
+
117
+ def _scale_timesteps(self, t):
118
+ # Scaling is done by the wrapped model.
119
+ return t
120
+
121
+
122
+ class _WrappedModel:
123
+ """
124
+ Converts the supplied t's (in the spaced range) to the original schedule's t's before calling the model.
125
+ """
126
+ def __init__(self, model, timestep_map, rescale_timesteps,
127
+ original_num_steps):
128
+ self.model = model
129
+ self.timestep_map = timestep_map
130
+ self.rescale_timesteps = rescale_timesteps
131
+ self.original_num_steps = original_num_steps
132
+
133
+ def forward(self, x, t, t_cond=None, **kwargs):
134
+ """
135
+ Args:
136
+ t: t's with different ranges (can be << T due to smaller eval T) need to be converted to the original t's
137
+ t_cond: the same as t but can be of different values
138
+ """
139
+ map_tensor = th.tensor(self.timestep_map,
140
+ device=t.device,
141
+ dtype=t.dtype)
142
+
143
+ def do(t):
144
+ new_ts = map_tensor[t]
145
+ if self.rescale_timesteps:
146
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
147
+ return new_ts
148
+
149
+ if t_cond is not None:
150
+ # support t_cond
151
+ t_cond = do(t_cond)
152
+
153
+ return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)
154
+
155
+ def __getattr__(self, name):
156
+ # allow for calling the model's methods
157
+ if hasattr(self.model, name):
158
+ func = getattr(self.model, name)
159
+ return func
160
+ raise AttributeError(name)
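To make the `timestep_map` bookkeeping concrete, a small standalone sketch (the values shown are what `space_timesteps(1000, "ddim5")` would retain; torch only):

```python
import torch as th

# The spaced process exposes "new" t in [0, len(use_timesteps)); before calling the
# network, _WrappedModel maps them back onto the original 1000-step schedule.
timestep_map = [0, 200, 400, 600, 800]      # sorted(space_timesteps(1000, "ddim5"))
t_new = th.tensor([0, 2, 4])                # timesteps seen by the spaced sampler
t_orig = th.tensor(timestep_map)[t_new]     # tensor([  0, 400, 800]) fed to the model
```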
diffusion/resample.py ADDED
@@ -0,0 +1,63 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+ import torch as th
5
+ import torch.distributed as dist
6
+
7
+
8
+ def create_named_schedule_sampler(name, diffusion):
9
+ """
10
+ Create a ScheduleSampler from a library of pre-defined samplers.
11
+
12
+ :param name: the name of the sampler.
13
+ :param diffusion: the diffusion object to sample for.
14
+ """
15
+ if name == "uniform":
16
+ return UniformSampler(diffusion)
17
+ else:
18
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
19
+
20
+
21
+ class ScheduleSampler(ABC):
22
+ """
23
+ A distribution over timesteps in the diffusion process, intended to reduce
24
+ variance of the objective.
25
+
26
+ By default, samplers perform unbiased importance sampling, in which the
27
+ objective's mean is unchanged.
28
+ However, subclasses may override sample() to change how the resampled
29
+ terms are reweighted, allowing for actual changes in the objective.
30
+ """
31
+ @abstractmethod
32
+ def weights(self):
33
+ """
34
+ Get a numpy array of weights, one per diffusion step.
35
+
36
+ The weights needn't be normalized, but must be positive.
37
+ """
38
+
39
+ def sample(self, batch_size, device):
40
+ """
41
+ Importance-sample timesteps for a batch.
42
+
43
+ :param batch_size: the number of timesteps.
44
+ :param device: the torch device to save to.
45
+ :return: a tuple (timesteps, weights):
46
+ - timesteps: a tensor of timestep indices.
47
+ - weights: a tensor of weights to scale the resulting losses.
48
+ """
49
+ w = self.weights()
50
+ p = w / np.sum(w)
51
+ indices_np = np.random.choice(len(p), size=(batch_size, ), p=p)
52
+ indices = th.from_numpy(indices_np).long().to(device)
53
+ weights_np = 1 / (len(p) * p[indices_np])
54
+ weights = th.from_numpy(weights_np).float().to(device)
55
+ return indices, weights
56
+
57
+
58
+ class UniformSampler(ScheduleSampler):
59
+ def __init__(self, num_timesteps):
60
+ self._weights = np.ones([num_timesteps])
61
+
62
+ def weights(self):
63
+ return self._weights
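The reweighting in `ScheduleSampler.sample()` keeps the objective unbiased: for t drawn with probability p_t, scaling the loss by 1/(T * p_t) recovers the uniform average over timesteps. A standalone Monte Carlo check (NumPy only):

```python
import numpy as np

T = 10
w = np.arange(1.0, T + 1)     # any positive (unnormalized) weights
p = w / w.sum()
f = np.random.rand(T)         # stand-in for per-timestep losses

t = np.random.choice(T, size=200_000, p=p)
estimate = np.mean(f[t] / (T * p[t]))
print(estimate, f.mean())     # the two agree up to Monte Carlo noise
```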
dist_utils.py ADDED
@@ -0,0 +1,42 @@
1
+ from typing import List
2
+ from torch import distributed
3
+
4
+
5
+ def barrier():
6
+ if distributed.is_initialized():
7
+ distributed.barrier()
8
+ else:
9
+ pass
10
+
11
+
12
+ def broadcast(data, src):
13
+ if distributed.is_initialized():
14
+ distributed.broadcast(data, src)
15
+ else:
16
+ pass
17
+
18
+
19
+ def all_gather(data: List, src):
20
+ if distributed.is_initialized():
21
+ distributed.all_gather(data, src)
22
+ else:
23
+ data[0] = src
24
+
25
+
26
+ def get_rank():
27
+ if distributed.is_initialized():
28
+ return distributed.get_rank()
29
+ else:
30
+ return 0
31
+
32
+
33
+ def get_world_size():
34
+ if distributed.is_initialized():
35
+ return distributed.get_world_size()
36
+ else:
37
+ return 1
38
+
39
+
40
+ def chunk_size(size, rank, world_size):
41
+ extra = rank < size % world_size
42
+ return size // world_size + extra
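`chunk_size` splits `size` items over `world_size` ranks as evenly as possible, giving ranks below `size % world_size` one extra item. For example (assuming `dist_utils` is importable from the repository root):

```python
from dist_utils import chunk_size

print([chunk_size(10, rank, 3) for rank in range(3)])  # [4, 3, 3], sums back to 10
```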
evals/ffhq128_autoenc_130M.txt ADDED
@@ -0,0 +1 @@
1
+ {}
evals/ffhq128_autoenc_latent.txt ADDED
@@ -0,0 +1 @@
1
+ {"fid_ema_T10_Tlatent10": 20.634624481201172}
experiment.py ADDED
@@ -0,0 +1,975 @@
1
+ import copy
2
+ import json
3
+ import os
4
+ import re
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import pytorch_lightning as pl
9
+ import torch
10
+ from numpy.lib.function_base import flip
11
+ from pytorch_lightning import loggers as pl_loggers
12
+ from pytorch_lightning.callbacks import *
13
+ from torch import nn
14
+ from torch.cuda import amp
15
+ from torch.distributions import Categorical
16
+ from torch.optim.optimizer import Optimizer
17
+ from torch.utils.data.dataset import ConcatDataset, TensorDataset
18
+ from torchvision.utils import make_grid, save_image
19
+
20
+ from config import *
21
+ from dataset import *
22
+ from dist_utils import *
23
+ from lmdb_writer import *
24
+ from metrics import *
25
+ from renderer import *
26
+
27
+
28
+ class LitModel(pl.LightningModule):
29
+ def __init__(self, conf: TrainConfig):
30
+ super().__init__()
31
+ assert conf.train_mode != TrainMode.manipulate
32
+ if conf.seed is not None:
33
+ pl.seed_everything(conf.seed)
34
+
35
+ self.save_hyperparameters(conf.as_dict_jsonable())
36
+
37
+ self.conf = conf
38
+
39
+ self.model = conf.make_model_conf().make_model()
40
+ self.ema_model = copy.deepcopy(self.model)
41
+ self.ema_model.requires_grad_(False)
42
+ self.ema_model.eval()
43
+
44
+ model_size = 0
45
+ for param in self.model.parameters():
46
+ model_size += param.data.nelement()
47
+ print('Model params: %.2f M' % (model_size / 1024 / 1024))
48
+
49
+ self.sampler = conf.make_diffusion_conf().make_sampler()
50
+ self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
51
+
52
+ # this is shared for both model and latent
53
+ self.T_sampler = conf.make_T_sampler()
54
+
55
+ if conf.train_mode.use_latent_net():
56
+ self.latent_sampler = conf.make_latent_diffusion_conf(
57
+ ).make_sampler()
58
+ self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf(
59
+ ).make_sampler()
60
+ else:
61
+ self.latent_sampler = None
62
+ self.eval_latent_sampler = None
63
+
64
+ # initial variables for consistent sampling
65
+ self.register_buffer(
66
+ 'x_T',
67
+ torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size))
68
+
69
+ if conf.pretrain is not None:
70
+ print(f'loading pretrain ... {conf.pretrain.name}')
71
+ state = torch.load(conf.pretrain.path, map_location='cpu')
72
+ print('step:', state['global_step'])
73
+ self.load_state_dict(state['state_dict'], strict=False)
74
+
75
+ if conf.latent_infer_path is not None:
76
+ print('loading latent stats ...')
77
+ state = torch.load(conf.latent_infer_path)
78
+ self.conds = state['conds']
79
+ self.register_buffer('conds_mean', state['conds_mean'][None, :])
80
+ self.register_buffer('conds_std', state['conds_std'][None, :])
81
+ else:
82
+ self.conds_mean = None
83
+ self.conds_std = None
84
+
85
+ def normalize(self, cond):
86
+ cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
87
+ self.device)
88
+ return cond
89
+
90
+ def denormalize(self, cond):
91
+ cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
92
+ self.device)
93
+ return cond
94
+
95
+ def sample(self, N, device, T=None, T_latent=None):
96
+ if T is None:
97
+ sampler = self.eval_sampler
98
+ latent_sampler = self.latent_sampler
99
+ else:
100
+ sampler = self.conf._make_diffusion_conf(T).make_sampler()
101
+ latent_sampler = self.conf._make_latent_diffusion_conf(T_latent).make_sampler()
102
+
103
+ noise = torch.randn(N,
104
+ 3,
105
+ self.conf.img_size,
106
+ self.conf.img_size,
107
+ device=device)
108
+ pred_img = render_uncondition(
109
+ self.conf,
110
+ self.ema_model,
111
+ noise,
112
+ sampler=sampler,
113
+ latent_sampler=latent_sampler,
114
+ conds_mean=self.conds_mean,
115
+ conds_std=self.conds_std,
116
+ )
117
+ pred_img = (pred_img + 1) / 2
118
+ return pred_img
119
+
120
+ def render(self, noise, cond=None, T=None):
121
+ if T is None:
122
+ sampler = self.eval_sampler
123
+ else:
124
+ sampler = self.conf._make_diffusion_conf(T).make_sampler()
125
+
126
+ if cond is not None:
127
+ pred_img = render_condition(self.conf,
128
+ self.ema_model,
129
+ noise,
130
+ sampler=sampler,
131
+ cond=cond)
132
+ else:
133
+ pred_img = render_uncondition(self.conf,
134
+ self.ema_model,
135
+ noise,
136
+ sampler=sampler,
137
+ latent_sampler=None)
138
+ pred_img = (pred_img + 1) / 2
139
+ return pred_img
140
+
141
+ def encode(self, x):
142
+ # TODO:
143
+ assert self.conf.model_type.has_autoenc()
144
+ cond = self.ema_model.encoder.forward(x)
145
+ return cond
146
+
147
+ def encode_stochastic(self, x, cond, T=None):
148
+ if T is None:
149
+ sampler = self.eval_sampler
150
+ else:
151
+ sampler = self.conf._make_diffusion_conf(T).make_sampler()
152
+ out = sampler.ddim_reverse_sample_loop(self.ema_model,
153
+ x,
154
+ model_kwargs={'cond': cond})
155
+ return out['sample']
156
+
157
+ def forward(self, noise=None, x_start=None, ema_model: bool = False):
158
+ with amp.autocast(False):
159
+ if ema_model:
160
+ model = self.ema_model
161
+ else:
162
+ model = self.model
163
+ gen = self.eval_sampler.sample(model=model,
164
+ noise=noise,
165
+ x_start=x_start)
166
+ return gen
167
+
168
+ def setup(self, stage=None) -> None:
169
+ """
170
+ make datasets & seeding each worker separately
171
+ """
172
+ ##############################################
173
+ # NEED TO SET THE SEED SEPARATELY HERE
174
+ if self.conf.seed is not None:
175
+ seed = self.conf.seed * get_world_size() + self.global_rank
176
+ np.random.seed(seed)
177
+ torch.manual_seed(seed)
178
+ torch.cuda.manual_seed(seed)
179
+ print('local seed:', seed)
180
+ ##############################################
181
+
182
+ self.train_data = self.conf.make_dataset()
183
+ print('train data:', len(self.train_data))
184
+ self.val_data = self.train_data
185
+ print('val data:', len(self.val_data))
186
+
187
+ def _train_dataloader(self, drop_last=True):
188
+ """
189
+ really make the dataloader
190
+ """
191
+ # make sure to use the fraction of batch size
192
+ # the batch size is global!
193
+ conf = self.conf.clone()
194
+ conf.batch_size = self.batch_size
195
+
196
+ dataloader = conf.make_loader(self.train_data,
197
+ shuffle=True,
198
+ drop_last=drop_last)
199
+ return dataloader
200
+
201
+ def train_dataloader(self):
202
+ """
203
+ return the dataloader, if diffusion mode => return image dataset
204
+ if latent mode => return the inferred latent dataset
205
+ """
206
+ print('on train dataloader start ...')
207
+ if self.conf.train_mode.require_dataset_infer():
208
+ if self.conds is None:
209
+ # usually we load self.conds from a file
210
+ # so we do not need to do this again!
211
+ self.conds = self.infer_whole_dataset()
212
+ # need to use float32! otherwise the mean & std will be off!
213
+ # (1, c)
214
+ self.conds_mean.data = self.conds.float().mean(dim=0,
215
+ keepdim=True)
216
+ self.conds_std.data = self.conds.float().std(dim=0,
217
+ keepdim=True)
218
+ print('mean:', self.conds_mean.mean(), 'std:',
219
+ self.conds_std.mean())
220
+
221
+ # return the dataset with pre-calculated conds
222
+ conf = self.conf.clone()
223
+ conf.batch_size = self.batch_size
224
+ data = TensorDataset(self.conds)
225
+ return conf.make_loader(data, shuffle=True)
226
+ else:
227
+ return self._train_dataloader()
228
+
229
+ @property
230
+ def batch_size(self):
231
+ """
232
+ local batch size for each worker
233
+ """
234
+ ws = get_world_size()
235
+ assert self.conf.batch_size % ws == 0
236
+ return self.conf.batch_size // ws
237
+
238
+ @property
239
+ def num_samples(self):
240
+ """
241
+ (global) batch size * iterations
242
+ """
243
+ # batch size here is global!
244
+ # global_step already takes into account the accum batches
245
+ return self.global_step * self.conf.batch_size_effective
246
+
247
+ def is_last_accum(self, batch_idx):
248
+ """
249
+ is it the last gradient accumulation loop?
250
+ used with gradient_accum > 1, to check whether the optimizer will perform "step" in this iteration
251
+ """
252
+ return (batch_idx + 1) % self.conf.accum_batches == 0
253
+
254
+ def infer_whole_dataset(self,
255
+ with_render=False,
256
+ T_render=None,
257
+ render_save_path=None):
258
+ """
259
+ predicting the latents given images using the encoder
260
+
261
+ Args:
262
+ both_flips: include both original and flipped images; no need, it's not an improvement
263
+ with_render: whether to also render the images corresponding to that latent
264
+ render_save_path: lmdb output for the rendered images
265
+ """
266
+ data = self.conf.make_dataset()
267
+ if isinstance(data, CelebAlmdb) and data.crop_d2c:
268
+ # special case where we need the d2c crop
269
+ data.transform = make_transform(self.conf.img_size,
270
+ flip_prob=0,
271
+ crop_d2c=True)
272
+ else:
273
+ data.transform = make_transform(self.conf.img_size, flip_prob=0)
274
+
275
+ # data = SubsetDataset(data, 21)
276
+
277
+ loader = self.conf.make_loader(
278
+ data,
279
+ shuffle=False,
280
+ drop_last=False,
281
+ batch_size=self.conf.batch_size_eval,
282
+ parallel=True,
283
+ )
284
+ model = self.ema_model
285
+ model.eval()
286
+ conds = []
287
+
288
+ if with_render:
289
+ sampler = self.conf._make_diffusion_conf(
290
+ T=T_render or self.conf.T_eval).make_sampler()
291
+
292
+ if self.global_rank == 0:
293
+ writer = LMDBImageWriter(render_save_path,
294
+ format='webp',
295
+ quality=100)
296
+ else:
297
+ writer = nullcontext()
298
+ else:
299
+ writer = nullcontext()
300
+
301
+ with writer:
302
+ for batch in tqdm(loader, total=len(loader), desc='infer'):
303
+ with torch.no_grad():
304
+ # (n, c)
305
+ # print('idx:', batch['index'])
306
+ cond = model.encoder(batch['img'].to(self.device))
307
+
308
+ # used for reordering to match the original dataset
309
+ idx = batch['index']
310
+ idx = self.all_gather(idx)
311
+ if idx.dim() == 2:
312
+ idx = idx.flatten(0, 1)
313
+ argsort = idx.argsort()
314
+
315
+ if with_render:
316
+ noise = torch.randn(len(cond),
317
+ 3,
318
+ self.conf.img_size,
319
+ self.conf.img_size,
320
+ device=self.device)
321
+ render = sampler.sample(model, noise=noise, cond=cond)
322
+ render = (render + 1) / 2
323
+ # print('render:', render.shape)
324
+ # (k, n, c, h, w)
325
+ render = self.all_gather(render)
326
+ if render.dim() == 5:
327
+ # (k*n, c)
328
+ render = render.flatten(0, 1)
329
+
330
+ # print('global_rank:', self.global_rank)
331
+
332
+ if self.global_rank == 0:
333
+ writer.put_images(render[argsort])
334
+
335
+ # (k, n, c)
336
+ cond = self.all_gather(cond)
337
+
338
+ if cond.dim() == 3:
339
+ # (k*n, c)
340
+ cond = cond.flatten(0, 1)
341
+
342
+ conds.append(cond[argsort].cpu())
343
+ # break
344
+ model.train()
345
+ # (N, c) cpu
346
+
347
+ conds = torch.cat(conds).float()
348
+ return conds
349
+
350
+ def training_step(self, batch, batch_idx):
351
+ """
352
+ given an input, calculate the loss function
353
+ no optimization at this stage.
354
+ """
355
+ with amp.autocast(False):
356
+ # batch size here is local!
357
+ # forward
358
+ if self.conf.train_mode.require_dataset_infer():
359
+ # this mode as pre-calculated cond
360
+ cond = batch[0]
361
+ if self.conf.latent_znormalize:
362
+ cond = (cond - self.conds_mean.to(
363
+ self.device)) / self.conds_std.to(self.device)
364
+ else:
365
+ imgs, idxs = batch['img'], batch['index']
366
+ # print(f'(rank {self.global_rank}) batch size:', len(imgs))
367
+ x_start = imgs
368
+
369
+ if self.conf.train_mode == TrainMode.diffusion:
370
+ """
371
+ main training mode!!!
372
+ """
373
+ # with numpy seed we have the problem that the sample t's are related!
374
+ t, weight = self.T_sampler.sample(len(x_start), x_start.device)
375
+ losses = self.sampler.training_losses(model=self.model,
376
+ x_start=x_start,
377
+ t=t)
378
+ elif self.conf.train_mode.is_latent_diffusion():
379
+ """
380
+ training the latent variables!
381
+ """
382
+ # diffusion on the latent
383
+ t, weight = self.T_sampler.sample(len(cond), cond.device)
384
+ latent_losses = self.latent_sampler.training_losses(
385
+ model=self.model.latent_net, x_start=cond, t=t)
386
+ # train only do the latent diffusion
387
+ losses = {
388
+ 'latent': latent_losses['loss'],
389
+ 'loss': latent_losses['loss']
390
+ }
391
+ else:
392
+ raise NotImplementedError()
393
+
394
+ loss = losses['loss'].mean()
395
+ # divide by accum batches to make the accumulated gradient exact!
396
+ for key in ['loss', 'vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']:
397
+ if key in losses:
398
+ losses[key] = self.all_gather(losses[key]).mean()
399
+
400
+ if self.global_rank == 0:
401
+ self.logger.experiment.add_scalar('loss', losses['loss'],
402
+ self.num_samples)
403
+ for key in ['vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']:
404
+ if key in losses:
405
+ self.logger.experiment.add_scalar(
406
+ f'loss/{key}', losses[key], self.num_samples)
407
+
408
+ return {'loss': loss}
409
+
410
+ def on_train_batch_end(self, outputs, batch, batch_idx: int,
411
+ dataloader_idx: int) -> None:
412
+ """
413
+ after each training step ...
414
+ """
415
+ if self.is_last_accum(batch_idx):
416
+ # only apply ema on the last gradient accumulation step,
417
+ # if it is the iteration that has optimizer.step()
418
+ if self.conf.train_mode == TrainMode.latent_diffusion:
419
+ # it trains only the latent hence change only the latent
420
+ ema(self.model.latent_net, self.ema_model.latent_net,
421
+ self.conf.ema_decay)
422
+ else:
423
+ ema(self.model, self.ema_model, self.conf.ema_decay)
424
+
425
+ # logging
426
+ if self.conf.train_mode.require_dataset_infer():
427
+ imgs = None
428
+ else:
429
+ imgs = batch['img']
430
+ self.log_sample(x_start=imgs)
431
+ self.evaluate_scores()
432
+
433
+ def on_before_optimizer_step(self, optimizer: Optimizer,
434
+ optimizer_idx: int) -> None:
435
+ # fix the fp16 + clip grad norm problem with pytorch lightning
436
+ # this is the currently correct way to do it
437
+ if self.conf.grad_clip > 0:
438
+ # from trainer.params_grads import grads_norm, iter_opt_params
439
+ params = [
440
+ p for group in optimizer.param_groups for p in group['params']
441
+ ]
442
+ # print('before:', grads_norm(iter_opt_params(optimizer)))
443
+ torch.nn.utils.clip_grad_norm_(params,
444
+ max_norm=self.conf.grad_clip)
445
+ # print('after:', grads_norm(iter_opt_params(optimizer)))
446
+
447
+ def log_sample(self, x_start):
448
+ """
449
+ put images to the tensorboard
450
+ """
451
+ def do(model,
452
+ postfix,
453
+ use_xstart,
454
+ save_real=False,
455
+ no_latent_diff=False,
456
+ interpolate=False):
457
+ model.eval()
458
+ with torch.no_grad():
459
+ all_x_T = self.split_tensor(self.x_T)
460
+ batch_size = min(len(all_x_T), self.conf.batch_size_eval)
461
+ # allow for superlarge models
462
+ loader = DataLoader(all_x_T, batch_size=batch_size)
463
+
464
+ Gen = []
465
+ for x_T in loader:
466
+ if use_xstart:
467
+ _xstart = x_start[:len(x_T)]
468
+ else:
469
+ _xstart = None
470
+
471
+ if self.conf.train_mode.is_latent_diffusion(
472
+ ) and not use_xstart:
473
+ # diffusion of the latent first
474
+ gen = render_uncondition(
475
+ conf=self.conf,
476
+ model=model,
477
+ x_T=x_T,
478
+ sampler=self.eval_sampler,
479
+ latent_sampler=self.eval_latent_sampler,
480
+ conds_mean=self.conds_mean,
481
+ conds_std=self.conds_std)
482
+ else:
483
+ if not use_xstart and self.conf.model_type.has_noise_to_cond(
484
+ ):
485
+ model: BeatGANsAutoencModel
486
+ # special case, it may not be stochastic, yet can sample
487
+ cond = torch.randn(len(x_T),
488
+ self.conf.style_ch,
489
+ device=self.device)
490
+ cond = model.noise_to_cond(cond)
491
+ else:
492
+ if interpolate:
493
+ with amp.autocast(self.conf.fp16):
494
+ cond = model.encoder(_xstart)
495
+ i = torch.randperm(len(cond))
496
+ cond = (cond + cond[i]) / 2
497
+ else:
498
+ cond = None
499
+ gen = self.eval_sampler.sample(model=model,
500
+ noise=x_T,
501
+ cond=cond,
502
+ x_start=_xstart)
503
+ Gen.append(gen)
504
+
505
+ gen = torch.cat(Gen)
506
+ gen = self.all_gather(gen)
507
+ if gen.dim() == 5:
508
+ # (n, c, h, w)
509
+ gen = gen.flatten(0, 1)
510
+
511
+ if save_real and use_xstart:
512
+ # save the original images to the tensorboard
513
+ real = self.all_gather(_xstart)
514
+ if real.dim() == 5:
515
+ real = real.flatten(0, 1)
516
+
517
+ if self.global_rank == 0:
518
+ grid_real = (make_grid(real) + 1) / 2
519
+ self.logger.experiment.add_image(
520
+ f'sample{postfix}/real', grid_real,
521
+ self.num_samples)
522
+
523
+ if self.global_rank == 0:
524
+ # save samples to the tensorboard
525
+ grid = (make_grid(gen) + 1) / 2
526
+ sample_dir = os.path.join(self.conf.logdir,
527
+ f'sample{postfix}')
528
+ if not os.path.exists(sample_dir):
529
+ os.makedirs(sample_dir)
530
+ path = os.path.join(sample_dir,
531
+ '%d.png' % self.num_samples)
532
+ save_image(grid, path)
533
+ self.logger.experiment.add_image(f'sample{postfix}', grid,
534
+ self.num_samples)
535
+ model.train()
536
+
537
+ if self.conf.sample_every_samples > 0 and is_time(
538
+ self.num_samples, self.conf.sample_every_samples,
539
+ self.conf.batch_size_effective):
540
+
541
+ if self.conf.train_mode.require_dataset_infer():
542
+ do(self.model, '', use_xstart=False)
543
+ do(self.ema_model, '_ema', use_xstart=False)
544
+ else:
545
+ if self.conf.model_type.has_autoenc(
546
+ ) and self.conf.model_type.can_sample():
547
+ do(self.model, '', use_xstart=False)
548
+ do(self.ema_model, '_ema', use_xstart=False)
549
+ # autoencoding mode
550
+ do(self.model, '_enc', use_xstart=True, save_real=True)
551
+ do(self.ema_model,
552
+ '_enc_ema',
553
+ use_xstart=True,
554
+ save_real=True)
555
+ elif self.conf.train_mode.use_latent_net():
556
+ do(self.model, '', use_xstart=False)
557
+ do(self.ema_model, '_ema', use_xstart=False)
558
+ # autoencoding mode
559
+ do(self.model, '_enc', use_xstart=True, save_real=True)
560
+ do(self.model,
561
+ '_enc_nodiff',
562
+ use_xstart=True,
563
+ save_real=True,
564
+ no_latent_diff=True)
565
+ do(self.ema_model,
566
+ '_enc_ema',
567
+ use_xstart=True,
568
+ save_real=True)
569
+ else:
570
+ do(self.model, '', use_xstart=True, save_real=True)
571
+ do(self.ema_model, '_ema', use_xstart=True, save_real=True)
572
+
573
+ def evaluate_scores(self):
574
+ """
575
+ evaluate FID and other scores during training (put to the tensorboard)
576
+ For, FID. It is a fast version with 5k images (gold standard is 50k).
577
+ Don't use its results in the paper!
578
+ """
579
+ def fid(model, postfix):
580
+ score = evaluate_fid(self.eval_sampler,
581
+ model,
582
+ self.conf,
583
+ device=self.device,
584
+ train_data=self.train_data,
585
+ val_data=self.val_data,
586
+ latent_sampler=self.eval_latent_sampler,
587
+ conds_mean=self.conds_mean,
588
+ conds_std=self.conds_std)
589
+ if self.global_rank == 0:
590
+ self.logger.experiment.add_scalar(f'FID{postfix}', score,
591
+ self.num_samples)
592
+ if not os.path.exists(self.conf.logdir):
593
+ os.makedirs(self.conf.logdir)
594
+ with open(os.path.join(self.conf.logdir, 'eval.txt'),
595
+ 'a') as f:
596
+ metrics = {
597
+ f'FID{postfix}': score,
598
+ 'num_samples': self.num_samples,
599
+ }
600
+ f.write(json.dumps(metrics) + "\n")
601
+
602
+ def lpips(model, postfix):
603
+ if self.conf.model_type.has_autoenc(
604
+ ) and self.conf.train_mode.is_autoenc():
605
+ # {'lpips', 'ssim', 'mse'}
606
+ score = evaluate_lpips(self.eval_sampler,
607
+ model,
608
+ self.conf,
609
+ device=self.device,
610
+ val_data=self.val_data,
611
+ latent_sampler=self.eval_latent_sampler)
612
+
613
+ if self.global_rank == 0:
614
+ for key, val in score.items():
615
+ self.logger.experiment.add_scalar(
616
+ f'{key}{postfix}', val, self.num_samples)
617
+
618
+ if self.conf.eval_every_samples > 0 and self.num_samples > 0 and is_time(
619
+ self.num_samples, self.conf.eval_every_samples,
620
+ self.conf.batch_size_effective):
621
+ print(f'eval fid @ {self.num_samples}')
622
+ lpips(self.model, '')
623
+ fid(self.model, '')
624
+
625
+ if self.conf.eval_ema_every_samples > 0 and self.num_samples > 0 and is_time(
626
+ self.num_samples, self.conf.eval_ema_every_samples,
627
+ self.conf.batch_size_effective):
628
+ print(f'eval fid ema @ {self.num_samples}')
629
+ fid(self.ema_model, '_ema')
630
+ # it's too slow
631
+ # lpips(self.ema_model, '_ema')
632
+
633
+ def configure_optimizers(self):
634
+ out = {}
635
+ if self.conf.optimizer == OptimizerType.adam:
636
+ optim = torch.optim.Adam(self.model.parameters(),
637
+ lr=self.conf.lr,
638
+ weight_decay=self.conf.weight_decay)
639
+ elif self.conf.optimizer == OptimizerType.adamw:
640
+ optim = torch.optim.AdamW(self.model.parameters(),
641
+ lr=self.conf.lr,
642
+ weight_decay=self.conf.weight_decay)
643
+ else:
644
+ raise NotImplementedError()
645
+ out['optimizer'] = optim
646
+ if self.conf.warmup > 0:
647
+ sched = torch.optim.lr_scheduler.LambdaLR(optim,
648
+ lr_lambda=WarmupLR(
649
+ self.conf.warmup))
650
+ out['lr_scheduler'] = {
651
+ 'scheduler': sched,
652
+ 'interval': 'step',
653
+ }
654
+ return out
655
+
656
+ def split_tensor(self, x):
657
+ """
658
+ extract the tensor for a corresponding "worker" in the batch dimension
659
+
660
+ Args:
661
+ x: (n, c)
662
+
663
+ Returns: x: (n_local, c)
664
+ """
665
+ n = len(x)
666
+ rank = self.global_rank
667
+ world_size = get_world_size()
668
+ # print(f'rank: {rank}/{world_size}')
669
+ per_rank = n // world_size
670
+ return x[rank * per_rank:(rank + 1) * per_rank]
671
+
672
+ def test_step(self, batch, *args, **kwargs):
673
+ """
674
+ for the "eval" mode.
675
+ We first select what to do according to the "conf.eval_programs".
676
+ test_step will only run for "one iteration" (it's a hack!).
677
+
678
+ We just want the multi-gpu support.
679
+ """
680
+ # make sure you seed each worker differently!
681
+ self.setup()
682
+
683
+ # it will run only one step!
684
+ print('global step:', self.global_step)
685
+ """
686
+ "infer" = predict the latent variables using the encoder on the whole dataset
687
+ """
688
+ if 'infer' in self.conf.eval_programs:
689
+ if 'infer' in self.conf.eval_programs:
690
+ print('infer ...')
691
+ conds = self.infer_whole_dataset().float()
692
+ # NOTE: always use this path for the latent.pkl files
693
+ save_path = f'checkpoints/{self.conf.name}/latent.pkl'
694
+ else:
695
+ raise NotImplementedError()
696
+
697
+ if self.global_rank == 0:
698
+ conds_mean = conds.mean(dim=0)
699
+ conds_std = conds.std(dim=0)
700
+ if not os.path.exists(os.path.dirname(save_path)):
701
+ os.makedirs(os.path.dirname(save_path))
702
+ torch.save(
703
+ {
704
+ 'conds': conds,
705
+ 'conds_mean': conds_mean,
706
+ 'conds_std': conds_std,
707
+ }, save_path)
708
+ """
709
+ "infer+render" = predict the latent variables using the encoder on the whole dataset
710
+ THIS ALSO GENERATES THE CORRESPONDING IMAGES
711
+ """
712
+ # infer + reconstruction quality of the input
713
+ for each in self.conf.eval_programs:
714
+ if each.startswith('infer+render'):
715
+ m = re.match(r'infer\+render([0-9]+)', each)
716
+ if m is not None:
717
+ T = int(m[1])
718
+ self.setup()
719
+ print(f'infer + reconstruction T{T} ...')
720
+ conds = self.infer_whole_dataset(
721
+ with_render=True,
722
+ T_render=T,
723
+ render_save_path=
724
+ f'latent_infer_render{T}/{self.conf.name}.lmdb',
725
+ )
726
+ save_path = f'latent_infer_render{T}/{self.conf.name}.pkl'
727
+ conds_mean = conds.mean(dim=0)
728
+ conds_std = conds.std(dim=0)
729
+ if not os.path.exists(os.path.dirname(save_path)):
730
+ os.makedirs(os.path.dirname(save_path))
731
+ torch.save(
732
+ {
733
+ 'conds': conds,
734
+ 'conds_mean': conds_mean,
735
+ 'conds_std': conds_std,
736
+ }, save_path)
737
+
738
+ # evals those "fidXX"
739
+ """
740
+ "fid<T>" = unconditional generation (conf.train_mode = diffusion).
741
+ Note: Diff. autoenc will still receive real images in this mode.
742
+ "fid<T>,<T_latent>" = unconditional generation for latent models (conf.train_mode = latent_diffusion).
743
+ Note: Diff. autoenc will still NOT receive real images in this mode,
744
+ but you need to make sure that the train_mode is latent_diffusion.
745
+ """
746
+ for each in self.conf.eval_programs:
747
+ if each.startswith('fid'):
748
+ m = re.match(r'fid\(([0-9]+),([0-9]+)\)', each)
749
+ clip_latent_noise = False
750
+ if m is not None:
751
+ # eval(T1,T2)
752
+ T = int(m[1])
753
+ T_latent = int(m[2])
754
+ print(f'evaluating FID T = {T}... latent T = {T_latent}')
755
+ else:
756
+ m = re.match(r'fidclip\(([0-9]+),([0-9]+)\)', each)
757
+ if m is not None:
758
+ # fidclip(T1,T2)
759
+ T = int(m[1])
760
+ T_latent = int(m[2])
761
+ clip_latent_noise = True
762
+ print(
763
+ f'evaluating FID (clip latent noise) T = {T}... latent T = {T_latent}'
764
+ )
765
+ else:
766
+ # evalT
767
+ _, T = each.split('fid')
768
+ T = int(T)
769
+ T_latent = None
770
+ print(f'evaluating FID T = {T}...')
771
+
772
+ self.train_dataloader()
773
+ sampler = self.conf._make_diffusion_conf(T=T).make_sampler()
774
+ if T_latent is not None:
775
+ latent_sampler = self.conf._make_latent_diffusion_conf(
776
+ T=T_latent).make_sampler()
777
+ else:
778
+ latent_sampler = None
779
+
780
+ conf = self.conf.clone()
781
+ conf.eval_num_images = 50_000
782
+ score = evaluate_fid(
783
+ sampler,
784
+ self.ema_model,
785
+ conf,
786
+ device=self.device,
787
+ train_data=self.train_data,
788
+ val_data=self.val_data,
789
+ latent_sampler=latent_sampler,
790
+ conds_mean=self.conds_mean,
791
+ conds_std=self.conds_std,
792
+ remove_cache=False,
793
+ clip_latent_noise=clip_latent_noise,
794
+ )
795
+ if T_latent is None:
796
+ self.log(f'fid_ema_T{T}', score)
797
+ else:
798
+ name = 'fid'
799
+ if clip_latent_noise:
800
+ name += '_clip'
801
+ name += f'_ema_T{T}_Tlatent{T_latent}'
802
+ self.log(name, score)
803
+ """
804
+ "recon<T>" = reconstruction & autoencoding (without noise inversion)
805
+ """
806
+ for each in self.conf.eval_programs:
807
+ if each.startswith('recon'):
808
+ self.model: BeatGANsAutoencModel
809
+ _, T = each.split('recon')
810
+ T = int(T)
811
+ print(f'evaluating reconstruction T = {T}...')
812
+
813
+ sampler = self.conf._make_diffusion_conf(T=T).make_sampler()
814
+
815
+ conf = self.conf.clone()
816
+ # eval whole val dataset
817
+ conf.eval_num_images = len(self.val_data)
818
+ # {'lpips', 'mse', 'ssim'}
819
+ score = evaluate_lpips(sampler,
820
+ self.ema_model,
821
+ conf,
822
+ device=self.device,
823
+ val_data=self.val_data,
824
+ latent_sampler=None)
825
+ for k, v in score.items():
826
+ self.log(f'{k}_ema_T{T}', v)
827
+ """
828
+ "inv<T>" = reconstruction with noise inversion
829
+ """
830
+ for each in self.conf.eval_programs:
831
+ if each.startswith('inv'):
832
+ self.model: BeatGANsAutoencModel
833
+ _, T = each.split('inv')
834
+ T = int(T)
835
+ print(
836
+ f'evaluating reconstruction with noise inversion T = {T}...'
837
+ )
838
+
839
+ sampler = self.conf._make_diffusion_conf(T=T).make_sampler()
840
+
841
+ conf = self.conf.clone()
842
+ # eval whole val dataset
843
+ conf.eval_num_images = len(self.val_data)
844
+ # {'lpips', 'mse', 'ssim'}
845
+ score = evaluate_lpips(sampler,
846
+ self.ema_model,
847
+ conf,
848
+ device=self.device,
849
+ val_data=self.val_data,
850
+ latent_sampler=None,
851
+ use_inverted_noise=True)
852
+ for k, v in score.items():
853
+ self.log(f'{k}_inv_ema_T{T}', v)
854
+
855
+
856
+ def ema(source, target, decay):
857
+ source_dict = source.state_dict()
858
+ target_dict = target.state_dict()
859
+ for key in source_dict.keys():
860
+ target_dict[key].data.copy_(target_dict[key].data * decay +
861
+ source_dict[key].data * (1 - decay))
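The `ema` helper above is a plain exponential moving average over the state dict. A self-contained sketch of one update (PyTorch only; the tiny Linear module is just for illustration):

```python
import copy
import torch

net = torch.nn.Linear(4, 4)
ema_net = copy.deepcopy(net).requires_grad_(False)

decay = 0.9999
with torch.no_grad():
    ema_sd = ema_net.state_dict()
    for k, p in net.state_dict().items():
        # target <- decay * target + (1 - decay) * source, exactly as ema() does above
        ema_sd[k].copy_(ema_sd[k] * decay + p * (1 - decay))
```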
862
+
863
+
864
+ class WarmupLR:
865
+ def __init__(self, warmup) -> None:
866
+ self.warmup = warmup
867
+
868
+ def __call__(self, step):
869
+ return min(step, self.warmup) / self.warmup
870
+
871
+
872
+ def is_time(num_samples, every, step_size):
873
+ closest = (num_samples // every) * every
874
+ return num_samples - closest < step_size
875
+
876
+
877
+ def train(conf: TrainConfig, gpus, nodes=1, mode: str = 'train'):
878
+ print('conf:', conf.name)
879
+ # assert not (conf.fp16 and conf.grad_clip > 0
880
+ # ), 'pytorch lightning has bug with amp + gradient clipping'
881
+ model = LitModel(conf)
882
+
883
+ if not os.path.exists(conf.logdir):
884
+ os.makedirs(conf.logdir)
885
+ checkpoint = ModelCheckpoint(dirpath=f'{conf.logdir}',
886
+ save_last=True,
887
+ save_top_k=1,
888
+ every_n_train_steps=conf.save_every_samples //
889
+ conf.batch_size_effective)
890
+ checkpoint_path = f'{conf.logdir}/last.ckpt'
891
+ print('ckpt path:', checkpoint_path)
892
+ if os.path.exists(checkpoint_path):
893
+ resume = checkpoint_path
894
+ print('resume!')
895
+ else:
896
+ if conf.continue_from is not None:
897
+ # continue from a checkpoint
898
+ resume = conf.continue_from.path
899
+ else:
900
+ resume = None
901
+
902
+ tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
903
+ name=None,
904
+ version='')
905
+
906
+ # from pytorch_lightning.
907
+
908
+ plugins = []
909
+ if len(gpus) == 1 and nodes == 1:
910
+ accelerator = None
911
+ else:
912
+ accelerator = 'ddp'
913
+ from pytorch_lightning.plugins import DDPPlugin
914
+
915
+ # important for working with gradient checkpoint
916
+ plugins.append(DDPPlugin(find_unused_parameters=False))
917
+
918
+ trainer = pl.Trainer(
919
+ max_steps=conf.total_samples // conf.batch_size_effective,
920
+ resume_from_checkpoint=resume,
921
+ gpus=gpus,
922
+ num_nodes=nodes,
923
+ accelerator=accelerator,
924
+ precision=16 if conf.fp16 else 32,
925
+ callbacks=[
926
+ checkpoint,
927
+ LearningRateMonitor(),
928
+ ],
929
+ # clip in the model instead
930
+ # gradient_clip_val=conf.grad_clip,
931
+ replace_sampler_ddp=True,
932
+ logger=tb_logger,
933
+ accumulate_grad_batches=conf.accum_batches,
934
+ plugins=plugins,
935
+ )
936
+
937
+ if mode == 'train':
938
+ trainer.fit(model)
939
+ elif mode == 'eval':
940
+ # load the latest checkpoint
941
+ # perform lpips
942
+ # dummy loader to allow calling "test_step"
943
+ dummy = DataLoader(TensorDataset(torch.tensor([0.] * conf.batch_size)),
944
+ batch_size=conf.batch_size)
945
+ eval_path = conf.eval_path or checkpoint_path
946
+ # conf.eval_num_images = 50
947
+ print('loading from:', eval_path)
948
+ state = torch.load(eval_path, map_location='cpu')
949
+ print('step:', state['global_step'])
950
+ model.load_state_dict(state['state_dict'])
951
+ # trainer.fit(model)
952
+ out = trainer.test(model, dataloaders=dummy)
953
+ # first (and only) loader
954
+ out = out[0]
955
+ print(out)
956
+
957
+ if get_rank() == 0:
958
+ # save to tensorboard
959
+ for k, v in out.items():
960
+ tb_logger.experiment.add_scalar(
961
+ k, v, state['global_step'] * conf.batch_size_effective)
962
+
963
+ # # save to file
964
+ # # make it a dict of list
965
+ # for k, v in out.items():
966
+ # out[k] = [v]
967
+ tgt = f'evals/{conf.name}.txt'
968
+ dirname = os.path.dirname(tgt)
969
+ if not os.path.exists(dirname):
970
+ os.makedirs(dirname)
971
+ with open(tgt, 'a') as f:
972
+ f.write(json.dumps(out) + "\n")
973
+ # pd.DataFrame(out).to_csv(tgt)
974
+ else:
975
+ raise NotImplementedError()
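For reference, a minimal launcher sketch for the `train()` entry point above (not part of this upload's diff). It assumes a config factory such as `ffhq128_autoenc_130M()` from the repo's `templates.py`; swap in whichever config you actually use.

```python
# hypothetical launcher for experiment.train(); the config factory name is an assumption
from templates import ffhq128_autoenc_130M  # assumed helper from templates.py
from experiment import train

if __name__ == '__main__':
    conf = ffhq128_autoenc_130M()
    # gpus is a list of device indices; more than one GPU (or nodes > 1) switches to DDP
    train(conf, gpus=[0], nodes=1, mode='train')
```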
experiment_classifier.py ADDED
@@ -0,0 +1,339 @@
1
+ from config import *
2
+ from dataset import *
3
+ import pandas as pd
4
+ import json
5
+ import os
6
+ import copy
7
+
8
+ import numpy as np
9
+ import pytorch_lightning as pl
10
+ from pytorch_lightning import loggers as pl_loggers
11
+ from pytorch_lightning.callbacks import *
12
+ import torch
13
+
14
+
15
+ class ZipLoader:
16
+ def __init__(self, loaders):
17
+ self.loaders = loaders
18
+
19
+ def __len__(self):
20
+ return len(self.loaders[0])
21
+
22
+ def __iter__(self):
23
+ for each in zip(*self.loaders):
24
+ yield each
25
+
26
+
27
+ class ClsModel(pl.LightningModule):
28
+ def __init__(self, conf: TrainConfig):
29
+ super().__init__()
30
+ assert conf.train_mode.is_manipulate()
31
+ if conf.seed is not None:
32
+ pl.seed_everything(conf.seed)
33
+
34
+ self.save_hyperparameters(conf.as_dict_jsonable())
35
+ self.conf = conf
36
+
37
+ # preparations
38
+ if conf.train_mode == TrainMode.manipulate:
39
+ # this is only important for training!
40
+ # the latent is freshly inferred to make sure it matches the image
41
+ # manipulating latents require the base model
42
+ self.model = conf.make_model_conf().make_model()
43
+ self.ema_model = copy.deepcopy(self.model)
44
+ self.model.requires_grad_(False)
45
+ self.ema_model.requires_grad_(False)
46
+ self.ema_model.eval()
47
+
48
+ if conf.pretrain is not None:
49
+ print(f'loading pretrain ... {conf.pretrain.name}')
50
+ state = torch.load(conf.pretrain.path, map_location='cpu')
51
+ print('step:', state['global_step'])
52
+ self.load_state_dict(state['state_dict'], strict=False)
53
+
54
+ # load the latent stats
55
+ if conf.manipulate_znormalize:
56
+ print('loading latent stats ...')
57
+ state = torch.load(conf.latent_infer_path)
58
+ self.conds = state['conds']
59
+ self.register_buffer('conds_mean',
60
+ state['conds_mean'][None, :])
61
+ self.register_buffer('conds_std', state['conds_std'][None, :])
62
+ else:
63
+ self.conds_mean = None
64
+ self.conds_std = None
65
+
66
+ if conf.manipulate_mode in [ManipulateMode.celebahq_all]:
67
+ num_cls = len(CelebAttrDataset.id_to_cls)
68
+ elif conf.manipulate_mode.is_single_class():
69
+ num_cls = 1
70
+ else:
71
+ raise NotImplementedError()
72
+
73
+ # classifier
74
+ if conf.train_mode == TrainMode.manipulate:
75
+ # latent manipulation requires only a linear classifier
76
+ self.classifier = nn.Linear(conf.style_ch, num_cls)
77
+ else:
78
+ raise NotImplementedError()
79
+
80
+ self.ema_classifier = copy.deepcopy(self.classifier)
81
+
82
+ def state_dict(self, *args, **kwargs):
83
+ # don't save the base model
84
+ out = {}
85
+ for k, v in super().state_dict(*args, **kwargs).items():
86
+ if k.startswith('model.'):
87
+ pass
88
+ elif k.startswith('ema_model.'):
89
+ pass
90
+ else:
91
+ out[k] = v
92
+ return out
93
+
94
+ def load_state_dict(self, state_dict, strict: bool = None):
95
+ if self.conf.train_mode == TrainMode.manipulate:
96
+ # change the default strict => False
97
+ if strict is None:
98
+ strict = False
99
+ else:
100
+ if strict is None:
101
+ strict = True
102
+ return super().load_state_dict(state_dict, strict=strict)
103
+
104
+ def normalize(self, cond):
105
+ cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
106
+ self.device)
107
+ return cond
108
+
109
+ def denormalize(self, cond):
110
+ cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
111
+ self.device)
112
+ return cond
113
+
114
+ def load_dataset(self):
115
+ if self.conf.manipulate_mode == ManipulateMode.d2c_fewshot:
116
+ return CelebD2CAttrFewshotDataset(
117
+ cls_name=self.conf.manipulate_cls,
118
+ K=self.conf.manipulate_shots,
119
+ img_folder=data_paths['celeba'],
120
+ img_size=self.conf.img_size,
121
+ seed=self.conf.manipulate_seed,
122
+ all_neg=False,
123
+ do_augment=True,
124
+ )
125
+ elif self.conf.manipulate_mode == ManipulateMode.d2c_fewshot_allneg:
126
+ # positive-unlabeled classifier needs to keep the class ratio 1:1
127
+ # we use two dataloaders, one for each class, to stabilize the training
128
+ img_folder = data_paths['celeba']
129
+
130
+ return [
131
+ CelebD2CAttrFewshotDataset(
132
+ cls_name=self.conf.manipulate_cls,
133
+ K=self.conf.manipulate_shots,
134
+ img_folder=img_folder,
135
+ img_size=self.conf.img_size,
136
+ only_cls_name=self.conf.manipulate_cls,
137
+ only_cls_value=1,
138
+ seed=self.conf.manipulate_seed,
139
+ all_neg=True,
140
+ do_augment=True),
141
+ CelebD2CAttrFewshotDataset(
142
+ cls_name=self.conf.manipulate_cls,
143
+ K=self.conf.manipulate_shots,
144
+ img_folder=img_folder,
145
+ img_size=self.conf.img_size,
146
+ only_cls_name=self.conf.manipulate_cls,
147
+ only_cls_value=-1,
148
+ seed=self.conf.manipulate_seed,
149
+ all_neg=True,
150
+ do_augment=True),
151
+ ]
152
+ elif self.conf.manipulate_mode == ManipulateMode.celebahq_all:
153
+ return CelebHQAttrDataset(data_paths['celebahq'],
154
+ self.conf.img_size,
155
+ data_paths['celebahq_anno'],
156
+ do_augment=True)
157
+ else:
158
+ raise NotImplementedError()
159
+
160
+ def setup(self, stage=None) -> None:
161
+ ##############################################
162
+ # NEED TO SET THE SEED SEPARATELY HERE
163
+ if self.conf.seed is not None:
164
+ seed = self.conf.seed * get_world_size() + self.global_rank
165
+ np.random.seed(seed)
166
+ torch.manual_seed(seed)
167
+ torch.cuda.manual_seed(seed)
168
+ print('local seed:', seed)
169
+ ##############################################
170
+
171
+ self.train_data = self.load_dataset()
172
+ if self.conf.manipulate_mode.is_fewshot():
173
+ # repeat the dataset to be larger (speed up the training)
174
+ if isinstance(self.train_data, list):
175
+ # fewshot-allneg has two datasets
176
+ # we resize them to be of equal sizes
177
+ a, b = self.train_data
178
+ self.train_data = [
179
+ Repeat(a, max(len(a), len(b))),
180
+ Repeat(b, max(len(a), len(b))),
181
+ ]
182
+ else:
183
+ self.train_data = Repeat(self.train_data, 100_000)
184
+
185
+ def train_dataloader(self):
186
+ # make sure to use the fraction of batch size
187
+ # the batch size is global!
188
+ conf = self.conf.clone()
189
+ conf.batch_size = self.batch_size
190
+ if isinstance(self.train_data, list):
191
+ dataloader = []
192
+ for each in self.train_data:
193
+ dataloader.append(
194
+ conf.make_loader(each, shuffle=True, drop_last=True))
195
+ dataloader = ZipLoader(dataloader)
196
+ else:
197
+ dataloader = conf.make_loader(self.train_data,
198
+ shuffle=True,
199
+ drop_last=True)
200
+ return dataloader
201
+
202
+ @property
203
+ def batch_size(self):
204
+ ws = get_world_size()
205
+ assert self.conf.batch_size % ws == 0
206
+ return self.conf.batch_size // ws
207
+
208
+ def training_step(self, batch, batch_idx):
209
+ self.ema_model: BeatGANsAutoencModel
210
+ if isinstance(batch, tuple):
211
+ a, b = batch
212
+ imgs = torch.cat([a['img'], b['img']])
213
+ labels = torch.cat([a['labels'], b['labels']])
214
+ else:
215
+ imgs = batch['img']
216
+ # print(f'({self.global_rank}) imgs:', imgs.shape)
217
+ labels = batch['labels']
218
+
219
+ if self.conf.train_mode == TrainMode.manipulate:
220
+ self.ema_model.eval()
221
+ with torch.no_grad():
222
+ # (n, c)
223
+ cond = self.ema_model.encoder(imgs)
224
+
225
+ if self.conf.manipulate_znormalize:
226
+ cond = self.normalize(cond)
227
+
228
+ # (n, cls)
229
+ pred = self.classifier.forward(cond)
230
+ pred_ema = self.ema_classifier.forward(cond)
231
+ elif self.conf.train_mode == TrainMode.manipulate_img:
232
+ # (n, cls)
233
+ pred = self.classifier.forward(imgs)
234
+ pred_ema = None
235
+ elif self.conf.train_mode == TrainMode.manipulate_imgt:
236
+ t, weight = self.T_sampler.sample(len(imgs), imgs.device)
237
+ imgs_t = self.sampler.q_sample(imgs, t)
238
+ pred = self.classifier.forward(imgs_t, t=t)
239
+ pred_ema = None
240
+ print('pred:', pred.shape)
241
+ else:
242
+ raise NotImplementedError()
243
+
244
+ if self.conf.manipulate_mode.is_celeba_attr():
245
+ gt = torch.where(labels > 0,
246
+ torch.ones_like(labels).float(),
247
+ torch.zeros_like(labels).float())
248
+ elif self.conf.manipulate_mode == ManipulateMode.relighting:
249
+ gt = labels
250
+ else:
251
+ raise NotImplementedError()
252
+
253
+ if self.conf.manipulate_loss == ManipulateLossType.bce:
254
+ loss = F.binary_cross_entropy_with_logits(pred, gt)
255
+ if pred_ema is not None:
256
+ loss_ema = F.binary_cross_entropy_with_logits(pred_ema, gt)
257
+ elif self.conf.manipulate_loss == ManipulateLossType.mse:
258
+ loss = F.mse_loss(pred, gt)
259
+ if pred_ema is not None:
260
+ loss_ema = F.mse_loss(pred_ema, gt)
261
+ else:
262
+ raise NotImplementedError()
263
+
264
+ self.log('loss', loss)
265
+ self.log('loss_ema', loss_ema)
266
+ return loss
267
+
268
+ def on_train_batch_end(self, outputs, batch, batch_idx: int,
269
+ dataloader_idx: int) -> None:
270
+ ema(self.classifier, self.ema_classifier, self.conf.ema_decay)
271
+
272
+ def configure_optimizers(self):
273
+ optim = torch.optim.Adam(self.classifier.parameters(),
274
+ lr=self.conf.lr,
275
+ weight_decay=self.conf.weight_decay)
276
+ return optim
277
+
278
+
279
+ def ema(source, target, decay):
280
+ source_dict = source.state_dict()
281
+ target_dict = target.state_dict()
282
+ for key in source_dict.keys():
283
+ target_dict[key].data.copy_(target_dict[key].data * decay +
284
+ source_dict[key].data * (1 - decay))
285
+
286
+
287
+ def train_cls(conf: TrainConfig, gpus):
288
+ print('conf:', conf.name)
289
+ model = ClsModel(conf)
290
+
291
+ if not os.path.exists(conf.logdir):
292
+ os.makedirs(conf.logdir)
293
+ checkpoint = ModelCheckpoint(
294
+ dirpath=f'{conf.logdir}',
295
+ save_last=True,
296
+ save_top_k=1,
297
+ # every_n_train_steps=conf.save_every_samples //
298
+ # conf.batch_size_effective,
299
+ )
300
+ checkpoint_path = f'{conf.logdir}/last.ckpt'
301
+ if os.path.exists(checkpoint_path):
302
+ resume = checkpoint_path
303
+ else:
304
+ if conf.continue_from is not None:
305
+ # continue from a checkpoint
306
+ resume = conf.continue_from.path
307
+ else:
308
+ resume = None
309
+
310
+ tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
311
+ name=None,
312
+ version='')
313
+
314
+ # from pytorch_lightning.
315
+
316
+ plugins = []
317
+ if len(gpus) == 1:
318
+ accelerator = None
319
+ else:
320
+ accelerator = 'ddp'
321
+ from pytorch_lightning.plugins import DDPPlugin
322
+ # important for working with gradient checkpoint
323
+ plugins.append(DDPPlugin(find_unused_parameters=False))
324
+
325
+ trainer = pl.Trainer(
326
+ max_steps=conf.total_samples // conf.batch_size_effective,
327
+ resume_from_checkpoint=resume,
328
+ gpus=gpus,
329
+ accelerator=accelerator,
330
+ precision=16 if conf.fp16 else 32,
331
+ callbacks=[
332
+ checkpoint,
333
+ ],
334
+ replace_sampler_ddp=True,
335
+ logger=tb_logger,
336
+ accumulate_grad_batches=conf.accum_batches,
337
+ plugins=plugins,
338
+ )
339
+ trainer.fit(model)
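A similarly hedged sketch for the classifier trainer above: `train_cls()` expects a manipulate-mode config, since the linear classifier is trained on frozen `ema_model.encoder` features. The factory name below is an assumption, mirroring the repo's classifier templates.

```python
# hypothetical launcher for experiment_classifier.train_cls(); factory name is an assumption
from templates_cls import ffhq256_autoenc_cls  # assumed helper, adjust to your configs
from experiment_classifier import train_cls

if __name__ == '__main__':
    conf = ffhq256_autoenc_cls()
    train_cls(conf, gpus=[0])
```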
imgs/sandy.JPG ADDED
imgs_align/sandy.png ADDED
imgs_interpolate/1_a.png ADDED
imgs_interpolate/1_b.png ADDED
imgs_manipulated/compare.png ADDED
imgs_manipulated/output.png ADDED
imgs_manipulated/sandy-wavyhair.png ADDED
install_requirements_for_colab.sh ADDED
@@ -0,0 +1,11 @@
1
+ !pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 pytorch-lightning==1.2.2 torchtext==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
2
+ !pip install scipy==1.5.4
3
+ !pip install numpy==1.19.5
4
+ !pip install tqdm
5
+ !pip install pytorch-fid==0.2.0
6
+ !pip install pandas==1.1.5
7
+ !pip install lpips==0.1.4
8
+ !pip install lmdb==1.2.1
9
+ !pip install ftfy
10
+ !pip install regex
11
+ !pip install dlib requests
interpolate.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
lmdb_writer.py ADDED
@@ -0,0 +1,131 @@
1
+ from io import BytesIO
2
+
3
+ import lmdb
4
+ from PIL import Image
5
+
6
+ import torch
7
+
8
+ from contextlib import contextmanager
9
+ from torch.utils.data import Dataset
10
+ from multiprocessing import Process, Queue
11
+ import os
12
+ import shutil
13
+
14
+
15
+ def convert(x, format, quality=100):
16
+ # to prevent locking!
17
+ torch.set_num_threads(1)
18
+
19
+ buffer = BytesIO()
20
+ x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
21
+ x = x.to(torch.uint8)
22
+ x = x.numpy()
23
+ img = Image.fromarray(x)
24
+ img.save(buffer, format=format, quality=quality)
25
+ val = buffer.getvalue()
26
+ return val
27
+
28
+
29
+ @contextmanager
30
+ def nullcontext():
31
+ yield
32
+
33
+
34
+ class _WriterWorker(Process):
35
+ def __init__(self, path, format, quality, zfill, q):
36
+ super().__init__()
37
+ if os.path.exists(path):
38
+ shutil.rmtree(path)
39
+
40
+ self.path = path
41
+ self.format = format
42
+ self.quality = quality
43
+ self.zfill = zfill
44
+ self.q = q
45
+ self.i = 0
46
+
47
+ def run(self):
48
+ if not os.path.exists(self.path):
49
+ os.makedirs(self.path)
50
+
51
+ with lmdb.open(self.path, map_size=1024**4, readahead=False) as env:
52
+ while True:
53
+ job = self.q.get()
54
+ if job is None:
55
+ break
56
+ with env.begin(write=True) as txn:
57
+ for x in job:
58
+ key = f"{str(self.i).zfill(self.zfill)}".encode(
59
+ "utf-8")
60
+ x = convert(x, self.format, self.quality)
61
+ txn.put(key, x)
62
+ self.i += 1
63
+
64
+ with env.begin(write=True) as txn:
65
+ txn.put("length".encode("utf-8"), str(self.i).encode("utf-8"))
66
+
67
+
68
+ class LMDBImageWriter:
69
+ def __init__(self, path, format='webp', quality=100, zfill=7) -> None:
70
+ self.path = path
71
+ self.format = format
72
+ self.quality = quality
73
+ self.zfill = zfill
74
+ self.queue = None
75
+ self.worker = None
76
+
77
+ def __enter__(self):
78
+ self.queue = Queue(maxsize=3)
79
+ self.worker = _WriterWorker(self.path, self.format, self.quality,
80
+ self.zfill, self.queue)
81
+ self.worker.start()
82
+
83
+ def put_images(self, tensor):
84
+ """
85
+ Args:
86
+ tensor: (n, c, h, w) [0-1] tensor
87
+ """
88
+ self.queue.put(tensor.cpu())
89
+ # with self.env.begin(write=True) as txn:
90
+ # for x in tensor:
91
+ # key = f"{str(self.i).zfill(self.zfill)}".encode("utf-8")
92
+ # x = convert(x, self.format, self.quality)
93
+ # txn.put(key, x)
94
+ # self.i += 1
95
+
96
+ def __exit__(self, *args, **kwargs):
97
+ self.queue.put(None)
98
+ self.queue.close()
99
+ self.worker.join()
100
+
101
+
102
+ class LMDBImageReader(Dataset):
103
+ def __init__(self, path, zfill: int = 7):
104
+ self.zfill = zfill
105
+ self.env = lmdb.open(
106
+ path,
107
+ max_readers=32,
108
+ readonly=True,
109
+ lock=False,
110
+ readahead=False,
111
+ meminit=False,
112
+ )
113
+
114
+ if not self.env:
115
+ raise IOError('Cannot open lmdb dataset', path)
116
+
117
+ with self.env.begin(write=False) as txn:
118
+ self.length = int(
119
+ txn.get('length'.encode('utf-8')).decode('utf-8'))
120
+
121
+ def __len__(self):
122
+ return self.length
123
+
124
+ def __getitem__(self, index):
125
+ with self.env.begin(write=False) as txn:
126
+ key = f'{str(index).zfill(self.zfill)}'.encode('utf-8')
127
+ img_bytes = txn.get(key)
128
+
129
+ buffer = BytesIO(img_bytes)
130
+ img = Image.open(buffer)
131
+ return img
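A small usage sketch for the LMDB helpers above: the writer streams `[0, 1]` image tensors to a background process, and the reader returns PIL images. Note that `__enter__` returns `None`, so keep a reference to the writer object. The dataset path below is illustrative.

```python
# sketch: round-trip through LMDBImageWriter / LMDBImageReader (path is illustrative)
import torch
from lmdb_writer import LMDBImageWriter, LMDBImageReader

imgs = torch.rand(4, 3, 64, 64)  # (n, c, h, w) tensors in [0, 1]

writer = LMDBImageWriter('datasets/example.lmdb', format='webp', quality=100)
with writer:                      # __enter__ returns None, so keep the reference
    writer.put_images(imgs)       # queued and written by the worker process

reader = LMDBImageReader('datasets/example.lmdb')
print(len(reader))                # 4
print(reader[0].size)             # (64, 64) PIL image
```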
manipulate.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
manipulate_note.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
metrics.py ADDED
@@ -0,0 +1,358 @@
1
+ import os
2
+ import shutil
3
+
4
+ import torch
5
+ import torchvision
6
+ from pytorch_fid import fid_score
7
+ from torch import distributed
8
+ from torch.utils.data import DataLoader
9
+ from torch.utils.data.distributed import DistributedSampler
10
+ from tqdm.autonotebook import tqdm, trange
11
+
12
+ from renderer import *
13
+ from config import *
14
+ from diffusion import Sampler
15
+ from dist_utils import *
16
+ import lpips
17
+ from ssim import ssim
18
+
19
+
20
+ def make_subset_loader(conf: TrainConfig,
21
+ dataset: Dataset,
22
+ batch_size: int,
23
+ shuffle: bool,
24
+ parallel: bool,
25
+ drop_last=True):
26
+ dataset = SubsetDataset(dataset, size=conf.eval_num_images)
27
+ if parallel and distributed.is_initialized():
28
+ sampler = DistributedSampler(dataset, shuffle=shuffle)
29
+ else:
30
+ sampler = None
31
+ return DataLoader(
32
+ dataset,
33
+ batch_size=batch_size,
34
+ sampler=sampler,
35
+ # when a sampler is given, use it instead of this option
36
+ shuffle=False if sampler else shuffle,
37
+ num_workers=conf.num_workers,
38
+ pin_memory=True,
39
+ drop_last=drop_last,
40
+ multiprocessing_context=get_context('fork'),
41
+ )
42
+
43
+
44
+ def evaluate_lpips(
45
+ sampler: Sampler,
46
+ model: Model,
47
+ conf: TrainConfig,
48
+ device,
49
+ val_data: Dataset,
50
+ latent_sampler: Sampler = None,
51
+ use_inverted_noise: bool = False,
52
+ ):
53
+ """
54
+ compare the generated images from autoencoder on validation dataset
55
+
56
+ Args:
57
+ use_inverted_noise: the noise is also inverted with DDIM
58
+ """
59
+ lpips_fn = lpips.LPIPS(net='alex').to(device)
60
+ val_loader = make_subset_loader(conf,
61
+ dataset=val_data,
62
+ batch_size=conf.batch_size_eval,
63
+ shuffle=False,
64
+ parallel=True)
65
+
66
+ model.eval()
67
+ with torch.no_grad():
68
+ scores = {
69
+ 'lpips': [],
70
+ 'mse': [],
71
+ 'ssim': [],
72
+ 'psnr': [],
73
+ }
74
+ for batch in tqdm(val_loader, desc='lpips'):
75
+ imgs = batch['img'].to(device)
76
+
77
+ if use_inverted_noise:
78
+ # inverse the noise
79
+ # with condition from the encoder
80
+ model_kwargs = {}
81
+ if conf.model_type.has_autoenc():
82
+ with torch.no_grad():
83
+ model_kwargs = model.encode(imgs)
84
+ x_T = sampler.ddim_reverse_sample_loop(
85
+ model=model,
86
+ x=imgs,
87
+ clip_denoised=True,
88
+ model_kwargs=model_kwargs)
89
+ x_T = x_T['sample']
90
+ else:
91
+ x_T = torch.randn((len(imgs), 3, conf.img_size, conf.img_size),
92
+ device=device)
93
+
94
+ if conf.model_type == ModelType.ddpm:
95
+ # the case where you want to calculate the inversion capability of the DDIM model
96
+ assert use_inverted_noise
97
+ pred_imgs = render_uncondition(
98
+ conf=conf,
99
+ model=model,
100
+ x_T=x_T,
101
+ sampler=sampler,
102
+ latent_sampler=latent_sampler,
103
+ )
104
+ else:
105
+ pred_imgs = render_condition(conf=conf,
106
+ model=model,
107
+ x_T=x_T,
108
+ x_start=imgs,
109
+ cond=None,
110
+ sampler=sampler)
111
+ # # returns {'cond', 'cond2'}
112
+ # conds = model.encode(imgs)
113
+ # pred_imgs = sampler.sample(model=model,
114
+ # noise=x_T,
115
+ # model_kwargs=conds)
116
+
117
+ # (n, 1, 1, 1) => (n, )
118
+ scores['lpips'].append(lpips_fn.forward(imgs, pred_imgs).view(-1))
119
+
120
+ # need to normalize into [0, 1]
121
+ norm_imgs = (imgs + 1) / 2
122
+ norm_pred_imgs = (pred_imgs + 1) / 2
123
+ # (n, )
124
+ scores['ssim'].append(
125
+ ssim(norm_imgs, norm_pred_imgs, size_average=False))
126
+ # (n, )
127
+ scores['mse'].append(
128
+ (norm_imgs - norm_pred_imgs).pow(2).mean(dim=[1, 2, 3]))
129
+ # (n, )
130
+ scores['psnr'].append(psnr(norm_imgs, norm_pred_imgs))
131
+ # (N, )
132
+ for key in scores.keys():
133
+ scores[key] = torch.cat(scores[key]).float()
134
+ model.train()
135
+
136
+ barrier()
137
+
138
+ # support multi-gpu
139
+ outs = {
140
+ key: [
141
+ torch.zeros(len(scores[key]), device=device)
142
+ for i in range(get_world_size())
143
+ ]
144
+ for key in scores.keys()
145
+ }
146
+ for key in scores.keys():
147
+ all_gather(outs[key], scores[key])
148
+
149
+ # final scores
150
+ for key in scores.keys():
151
+ scores[key] = torch.cat(outs[key]).mean().item()
152
+
153
+ # {'lpips', 'mse', 'ssim'}
154
+ return scores
155
+
156
+
157
+ def psnr(img1, img2):
158
+ """
159
+ Args:
160
+ img1: (n, c, h, w)
161
+ """
162
+ v_max = 1.
163
+ # (n,)
164
+ mse = torch.mean((img1 - img2)**2, dim=[1, 2, 3])
165
+ return 20 * torch.log10(v_max / torch.sqrt(mse))
166
+
167
+
168
+ def evaluate_fid(
169
+ sampler: Sampler,
170
+ model: Model,
171
+ conf: TrainConfig,
172
+ device,
173
+ train_data: Dataset,
174
+ val_data: Dataset,
175
+ latent_sampler: Sampler = None,
176
+ conds_mean=None,
177
+ conds_std=None,
178
+ remove_cache: bool = True,
179
+ clip_latent_noise: bool = False,
180
+ ):
181
+ assert conf.fid_cache is not None
182
+ if get_rank() == 0:
183
+ # no parallel
184
+ # validation data for a comparing FID
185
+ val_loader = make_subset_loader(conf,
186
+ dataset=val_data,
187
+ batch_size=conf.batch_size_eval,
188
+ shuffle=False,
189
+ parallel=False)
190
+
191
+ # put the val images to a directory
192
+ cache_dir = f'{conf.fid_cache}_{conf.eval_num_images}'
193
+ if (os.path.exists(cache_dir)
194
+ and len(os.listdir(cache_dir)) < conf.eval_num_images):
195
+ shutil.rmtree(cache_dir)
196
+
197
+ if not os.path.exists(cache_dir):
198
+ # write files to the cache
199
+ # the images are normalized, hence need to denormalize first
200
+ loader_to_path(val_loader, cache_dir, denormalize=True)
201
+
202
+ # create the generate dir
203
+ if os.path.exists(conf.generate_dir):
204
+ shutil.rmtree(conf.generate_dir)
205
+ os.makedirs(conf.generate_dir)
206
+
207
+ barrier()
208
+
209
+ world_size = get_world_size()
210
+ rank = get_rank()
211
+ batch_size = chunk_size(conf.batch_size_eval, rank, world_size)
212
+
213
+ def filename(idx):
214
+ return world_size * idx + rank
215
+
216
+ model.eval()
217
+ with torch.no_grad():
218
+ if conf.model_type.can_sample():
219
+ eval_num_images = chunk_size(conf.eval_num_images, rank,
220
+ world_size)
221
+ desc = "generating images"
222
+ for i in trange(0, eval_num_images, batch_size, desc=desc):
223
+ batch_size = min(batch_size, eval_num_images - i)
224
+ x_T = torch.randn(
225
+ (batch_size, 3, conf.img_size, conf.img_size),
226
+ device=device)
227
+ batch_images = render_uncondition(
228
+ conf=conf,
229
+ model=model,
230
+ x_T=x_T,
231
+ sampler=sampler,
232
+ latent_sampler=latent_sampler,
233
+ conds_mean=conds_mean,
234
+ conds_std=conds_std).cpu()
235
+
236
+ batch_images = (batch_images + 1) / 2
237
+ # keep the generated images
238
+ for j in range(len(batch_images)):
239
+ img_name = filename(i + j)
240
+ torchvision.utils.save_image(
241
+ batch_images[j],
242
+ os.path.join(conf.generate_dir, f'{img_name}.png'))
243
+ elif conf.model_type == ModelType.autoencoder:
244
+ if conf.train_mode.is_latent_diffusion():
245
+ # evaluate autoencoder + latent diffusion (doesn't give the images)
246
+ model: BeatGANsAutoencModel
247
+ eval_num_images = chunk_size(conf.eval_num_images, rank,
248
+ world_size)
249
+ desc = "generating images"
250
+ for i in trange(0, eval_num_images, batch_size, desc=desc):
251
+ batch_size = min(batch_size, eval_num_images - i)
252
+ x_T = torch.randn(
253
+ (batch_size, 3, conf.img_size, conf.img_size),
254
+ device=device)
255
+ batch_images = render_uncondition(
256
+ conf=conf,
257
+ model=model,
258
+ x_T=x_T,
259
+ sampler=sampler,
260
+ latent_sampler=latent_sampler,
261
+ conds_mean=conds_mean,
262
+ conds_std=conds_std,
263
+ clip_latent_noise=clip_latent_noise,
264
+ ).cpu()
265
+ batch_images = (batch_images + 1) / 2
266
+ # keep the generated images
267
+ for j in range(len(batch_images)):
268
+ img_name = filename(i + j)
269
+ torchvision.utils.save_image(
270
+ batch_images[j],
271
+ os.path.join(conf.generate_dir, f'{img_name}.png'))
272
+ else:
273
+ # evaluate autoencoder (given the images)
274
+ # to make the FID fair, autoencoder must not see the validation dataset
275
+ # also shuffle to make it closer to unconditional generation
276
+ train_loader = make_subset_loader(conf,
277
+ dataset=train_data,
278
+ batch_size=batch_size,
279
+ shuffle=True,
280
+ parallel=True)
281
+
282
+ i = 0
283
+ for batch in tqdm(train_loader, desc='generating images'):
284
+ imgs = batch['img'].to(device)
285
+ x_T = torch.randn(
286
+ (len(imgs), 3, conf.img_size, conf.img_size),
287
+ device=device)
288
+ batch_images = render_condition(
289
+ conf=conf,
290
+ model=model,
291
+ x_T=x_T,
292
+ x_start=imgs,
293
+ cond=None,
294
+ sampler=sampler,
295
+ latent_sampler=latent_sampler).cpu()
296
+ # model: BeatGANsAutoencModel
297
+ # # returns {'cond', 'cond2'}
298
+ # conds = model.encode(imgs)
299
+ # batch_images = sampler.sample(model=model,
300
+ # noise=x_T,
301
+ # model_kwargs=conds).cpu()
302
+ # denormalize the images
303
+ batch_images = (batch_images + 1) / 2
304
+ # keep the generated images
305
+ for j in range(len(batch_images)):
306
+ img_name = filename(i + j)
307
+ torchvision.utils.save_image(
308
+ batch_images[j],
309
+ os.path.join(conf.generate_dir, f'{img_name}.png'))
310
+ i += len(imgs)
311
+ else:
312
+ raise NotImplementedError()
313
+ model.train()
314
+
315
+ barrier()
316
+
317
+ if get_rank() == 0:
318
+ fid = fid_score.calculate_fid_given_paths(
319
+ [cache_dir, conf.generate_dir],
320
+ batch_size,
321
+ device=device,
322
+ dims=2048)
323
+
324
+ # remove the cache
325
+ if remove_cache and os.path.exists(conf.generate_dir):
326
+ shutil.rmtree(conf.generate_dir)
327
+
328
+ barrier()
329
+
330
+ if get_rank() == 0:
331
+ # need to convert to float! otherwise the broadcast value is wrong
332
+ fid = torch.tensor(float(fid), device=device)
333
+ broadcast(fid, 0)
334
+ else:
335
+ fid = torch.tensor(0., device=device)
336
+ broadcast(fid, 0)
337
+ fid = fid.item()
338
+ print(f'fid ({get_rank()}):', fid)
339
+
340
+ return fid
341
+
342
+
343
+ def loader_to_path(loader: DataLoader, path: str, denormalize: bool):
344
+ # not process safe!
345
+
346
+ if not os.path.exists(path):
347
+ os.makedirs(path)
348
+
349
+ # write the loader to files
350
+ i = 0
351
+ for batch in tqdm(loader, desc='copy images'):
352
+ imgs = batch['img']
353
+ if denormalize:
354
+ imgs = (imgs + 1) / 2
355
+ for j in range(len(imgs)):
356
+ torchvision.utils.save_image(imgs[j],
357
+ os.path.join(path, f'{i+j}.png'))
358
+ i += len(imgs)
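As a quick sanity check of the `psnr()` helper above: it implements PSNR = 20·log10(v_max / √MSE) with v_max = 1, so a uniform perturbation of about 0.01 on [0, 1] images should give roughly 40 dB.

```python
# sketch: exercising metrics.psnr() on dummy data
import torch
from metrics import psnr

a = torch.rand(2, 3, 64, 64)          # two fake images in [0, 1]
b = (a + 0.01).clamp(0, 1)            # ~0.01 perturbation -> MSE ~ 1e-4
print(psnr(a, b))                      # tensor of shape (2,), roughly 40 dB each
```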
model/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from typing import Union
2
+ from .unet import BeatGANsUNetModel, BeatGANsUNetConfig
3
+ from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel
4
+
5
+ Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
6
+ ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]
model/blocks.py ADDED
@@ -0,0 +1,567 @@
1
+ import math
2
+ from abc import abstractmethod
3
+ from dataclasses import dataclass
4
+ from numbers import Number
5
+
6
+ import numpy as np
+ import torch as th
7
+ import torch.nn.functional as F
8
+ from choices import *
9
+ from config_base import BaseConfig
10
+ from torch import nn
11
+
12
+ from .nn import (avg_pool_nd, conv_nd, linear, normalization,
13
+ timestep_embedding, torch_checkpoint, zero_module)
14
+
15
+
16
+ class ScaleAt(Enum):
17
+ after_norm = 'afternorm'
18
+
19
+
20
+ class TimestepBlock(nn.Module):
21
+ """
22
+ Any module where forward() takes timestep embeddings as a second argument.
23
+ """
24
+ @abstractmethod
25
+ def forward(self, x, emb=None, cond=None, lateral=None):
26
+ """
27
+ Apply the module to `x` given `emb` timestep embeddings.
28
+ """
29
+
30
+
31
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
32
+ """
33
+ A sequential module that passes timestep embeddings to the children that
34
+ support it as an extra input.
35
+ """
36
+ def forward(self, x, emb=None, cond=None, lateral=None):
37
+ for layer in self:
38
+ if isinstance(layer, TimestepBlock):
39
+ x = layer(x, emb=emb, cond=cond, lateral=lateral)
40
+ else:
41
+ x = layer(x)
42
+ return x
43
+
44
+
45
+ @dataclass
46
+ class ResBlockConfig(BaseConfig):
47
+ channels: int
48
+ emb_channels: int
49
+ dropout: float
50
+ out_channels: int = None
51
+ # condition the resblock with time (and encoder's output)
52
+ use_condition: bool = True
53
+ # whether to use 3x3 conv for skip path when the channels aren't matched
54
+ use_conv: bool = False
55
+ # dimension of conv (always 2 = 2d)
56
+ dims: int = 2
57
+ # gradient checkpoint
58
+ use_checkpoint: bool = False
59
+ up: bool = False
60
+ down: bool = False
61
+ # whether to condition with both time & encoder's output
62
+ two_cond: bool = False
63
+ # number of encoders' output channels
64
+ cond_emb_channels: int = None
65
+ # suggest: False
66
+ has_lateral: bool = False
67
+ lateral_channels: int = None
68
+ # whether to init the convolution with zero weights
69
+ # this is default from BeatGANs and seems to help learning
70
+ use_zero_module: bool = True
71
+
72
+ def __post_init__(self):
73
+ self.out_channels = self.out_channels or self.channels
74
+ self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
75
+
76
+ def make_model(self):
77
+ return ResBlock(self)
78
+
79
+
80
+ class ResBlock(TimestepBlock):
81
+ """
82
+ A residual block that can optionally change the number of channels.
83
+
84
+ total layers:
85
+ in_layers
86
+ - norm
87
+ - act
88
+ - conv
89
+ out_layers
90
+ - norm
91
+ - (modulation)
92
+ - act
93
+ - conv
94
+ """
95
+ def __init__(self, conf: ResBlockConfig):
96
+ super().__init__()
97
+ self.conf = conf
98
+
99
+ #############################
100
+ # IN LAYERS
101
+ #############################
102
+ assert conf.lateral_channels is None
103
+ layers = [
104
+ normalization(conf.channels),
105
+ nn.SiLU(),
106
+ conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1)
107
+ ]
108
+ self.in_layers = nn.Sequential(*layers)
109
+
110
+ self.updown = conf.up or conf.down
111
+
112
+ if conf.up:
113
+ self.h_upd = Upsample(conf.channels, False, conf.dims)
114
+ self.x_upd = Upsample(conf.channels, False, conf.dims)
115
+ elif conf.down:
116
+ self.h_upd = Downsample(conf.channels, False, conf.dims)
117
+ self.x_upd = Downsample(conf.channels, False, conf.dims)
118
+ else:
119
+ self.h_upd = self.x_upd = nn.Identity()
120
+
121
+ #############################
122
+ # OUT LAYERS CONDITIONS
123
+ #############################
124
+ if conf.use_condition:
125
+ # condition layers for the out_layers
126
+ self.emb_layers = nn.Sequential(
127
+ nn.SiLU(),
128
+ linear(conf.emb_channels, 2 * conf.out_channels),
129
+ )
130
+
131
+ if conf.two_cond:
132
+ self.cond_emb_layers = nn.Sequential(
133
+ nn.SiLU(),
134
+ linear(conf.cond_emb_channels, conf.out_channels),
135
+ )
136
+ #############################
137
+ # OUT LAYERS (ignored when there is no condition)
138
+ #############################
139
+ # original version
140
+ conv = conv_nd(conf.dims,
141
+ conf.out_channels,
142
+ conf.out_channels,
143
+ 3,
144
+ padding=1)
145
+ if conf.use_zero_module:
146
+ # zero out the weights
147
+ # it seems to help training
148
+ conv = zero_module(conv)
149
+
150
+ # construct the layers
151
+ # - norm
152
+ # - (modulation)
153
+ # - act
154
+ # - dropout
155
+ # - conv
156
+ layers = []
157
+ layers += [
158
+ normalization(conf.out_channels),
159
+ nn.SiLU(),
160
+ nn.Dropout(p=conf.dropout),
161
+ conv,
162
+ ]
163
+ self.out_layers = nn.Sequential(*layers)
164
+
165
+ #############################
166
+ # SKIP LAYERS
167
+ #############################
168
+ if conf.out_channels == conf.channels:
169
+ # cannot be used with gatedconv, also gatedconv is always used as the first block
170
+ self.skip_connection = nn.Identity()
171
+ else:
172
+ if conf.use_conv:
173
+ kernel_size = 3
174
+ padding = 1
175
+ else:
176
+ kernel_size = 1
177
+ padding = 0
178
+
179
+ self.skip_connection = conv_nd(conf.dims,
180
+ conf.channels,
181
+ conf.out_channels,
182
+ kernel_size,
183
+ padding=padding)
184
+
185
+ def forward(self, x, emb=None, cond=None, lateral=None):
186
+ """
187
+ Apply the block to a Tensor, conditioned on a timestep embedding.
188
+
189
+ Args:
190
+ x: input
191
+ lateral: lateral connection from the encoder
192
+ """
193
+ return torch_checkpoint(self._forward, (x, emb, cond, lateral),
194
+ self.conf.use_checkpoint)
195
+
196
+ def _forward(
197
+ self,
198
+ x,
199
+ emb=None,
200
+ cond=None,
201
+ lateral=None,
202
+ ):
203
+ """
204
+ Args:
205
+ lateral: required if "has_lateral" and non-gated, with gated, it can be supplied optionally
206
+ """
207
+ if self.conf.has_lateral:
208
+ # lateral may be supplied even if it isn't required
209
+ # the model will take the lateral only if "has_lateral"
210
+ assert lateral is not None
211
+ x = th.cat([x, lateral], dim=1)
212
+
213
+ if self.updown:
214
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
215
+ h = in_rest(x)
216
+ h = self.h_upd(h)
217
+ x = self.x_upd(x)
218
+ h = in_conv(h)
219
+ else:
220
+ h = self.in_layers(x)
221
+
222
+ if self.conf.use_condition:
223
+ # it's possible that the network may not receive the time emb
224
+ # this happens with autoenc and setting the time_at
225
+ if emb is not None:
226
+ emb_out = self.emb_layers(emb).type(h.dtype)
227
+ else:
228
+ emb_out = None
229
+
230
+ if self.conf.two_cond:
231
+ # it's possible that the network is two_cond
232
+ # but it doesn't get the second condition
233
+ # in which case, we ignore the second condition
234
+ # and treat as if the network has one condition
235
+ if cond is None:
236
+ cond_out = None
237
+ else:
238
+ cond_out = self.cond_emb_layers(cond).type(h.dtype)
239
+
240
+ if cond_out is not None:
241
+ while len(cond_out.shape) < len(h.shape):
242
+ cond_out = cond_out[..., None]
243
+ else:
244
+ cond_out = None
245
+
246
+ # this is the new refactored code
247
+ h = apply_conditions(
248
+ h=h,
249
+ emb=emb_out,
250
+ cond=cond_out,
251
+ layers=self.out_layers,
252
+ scale_bias=1,
253
+ in_channels=self.conf.out_channels,
254
+ up_down_layer=None,
255
+ )
256
+
257
+ return self.skip_connection(x) + h
258
+
259
+
260
+ def apply_conditions(
261
+ h,
262
+ emb=None,
263
+ cond=None,
264
+ layers: nn.Sequential = None,
265
+ scale_bias: float = 1,
266
+ in_channels: int = 512,
267
+ up_down_layer: nn.Module = None,
268
+ ):
269
+ """
270
+ apply conditions on the feature maps
271
+
272
+ Args:
273
+ emb: time conditional (ready to scale + shift)
274
+ cond: encoder's conditional (ready to scale + shift)
275
+ """
276
+ two_cond = emb is not None and cond is not None
277
+
278
+ if emb is not None:
279
+ # adjusting shapes
280
+ while len(emb.shape) < len(h.shape):
281
+ emb = emb[..., None]
282
+
283
+ if two_cond:
284
+ # adjusting shapes
285
+ while len(cond.shape) < len(h.shape):
286
+ cond = cond[..., None]
287
+ # time first
288
+ scale_shifts = [emb, cond]
289
+ else:
290
+ # "cond" is not used with single cond mode
291
+ scale_shifts = [emb]
292
+
293
+ # support scale, shift or shift only
294
+ for i, each in enumerate(scale_shifts):
295
+ if each is None:
296
+ # special case: the condition is not provided
297
+ a = None
298
+ b = None
299
+ else:
300
+ if each.shape[1] == in_channels * 2:
301
+ a, b = th.chunk(each, 2, dim=1)
302
+ else:
303
+ a = each
304
+ b = None
305
+ scale_shifts[i] = (a, b)
306
+
307
+ # condition scale bias could be a list
308
+ if isinstance(scale_bias, Number):
309
+ biases = [scale_bias] * len(scale_shifts)
310
+ else:
311
+ # a list
312
+ biases = scale_bias
313
+
314
+ # default, the scale & shift are applied after the group norm but BEFORE SiLU
315
+ pre_layers, post_layers = layers[0], layers[1:]
316
+
317
+ # split the post layers to be able to scale up or down before conv
318
+ # post layers will contain only the conv
319
+ mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
320
+
321
+ h = pre_layers(h)
322
+ # scale and shift for each condition
323
+ for i, (scale, shift) in enumerate(scale_shifts):
324
+ # if scale is None, it indicates that the condition is not provided
325
+ if scale is not None:
326
+ h = h * (biases[i] + scale)
327
+ if shift is not None:
328
+ h = h + shift
329
+ h = mid_layers(h)
330
+
331
+ # upscale or downscale if any just before the last conv
332
+ if up_down_layer is not None:
333
+ h = up_down_layer(h)
334
+ h = post_layers(h)
335
+ return h
336
+
337
+
338
+ class Upsample(nn.Module):
339
+ """
340
+ An upsampling layer with an optional convolution.
341
+
342
+ :param channels: channels in the inputs and outputs.
343
+ :param use_conv: a bool determining if a convolution is applied.
344
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
345
+ upsampling occurs in the inner-two dimensions.
346
+ """
347
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
348
+ super().__init__()
349
+ self.channels = channels
350
+ self.out_channels = out_channels or channels
351
+ self.use_conv = use_conv
352
+ self.dims = dims
353
+ if use_conv:
354
+ self.conv = conv_nd(dims,
355
+ self.channels,
356
+ self.out_channels,
357
+ 3,
358
+ padding=1)
359
+
360
+ def forward(self, x):
361
+ assert x.shape[1] == self.channels
362
+ if self.dims == 3:
363
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
364
+ mode="nearest")
365
+ else:
366
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
367
+ if self.use_conv:
368
+ x = self.conv(x)
369
+ return x
370
+
371
+
372
+ class Downsample(nn.Module):
373
+ """
374
+ A downsampling layer with an optional convolution.
375
+
376
+ :param channels: channels in the inputs and outputs.
377
+ :param use_conv: a bool determining if a convolution is applied.
378
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
379
+ downsampling occurs in the inner-two dimensions.
380
+ """
381
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
382
+ super().__init__()
383
+ self.channels = channels
384
+ self.out_channels = out_channels or channels
385
+ self.use_conv = use_conv
386
+ self.dims = dims
387
+ stride = 2 if dims != 3 else (1, 2, 2)
388
+ if use_conv:
389
+ self.op = conv_nd(dims,
390
+ self.channels,
391
+ self.out_channels,
392
+ 3,
393
+ stride=stride,
394
+ padding=1)
395
+ else:
396
+ assert self.channels == self.out_channels
397
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
398
+
399
+ def forward(self, x):
400
+ assert x.shape[1] == self.channels
401
+ return self.op(x)
402
+
403
+
404
+ class AttentionBlock(nn.Module):
405
+ """
406
+ An attention block that allows spatial positions to attend to each other.
407
+
408
+ Originally ported from here, but adapted to the N-d case.
409
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
410
+ """
411
+ def __init__(
412
+ self,
413
+ channels,
414
+ num_heads=1,
415
+ num_head_channels=-1,
416
+ use_checkpoint=False,
417
+ use_new_attention_order=False,
418
+ ):
419
+ super().__init__()
420
+ self.channels = channels
421
+ if num_head_channels == -1:
422
+ self.num_heads = num_heads
423
+ else:
424
+ assert (
425
+ channels % num_head_channels == 0
426
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
427
+ self.num_heads = channels // num_head_channels
428
+ self.use_checkpoint = use_checkpoint
429
+ self.norm = normalization(channels)
430
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
431
+ if use_new_attention_order:
432
+ # split qkv before split heads
433
+ self.attention = QKVAttention(self.num_heads)
434
+ else:
435
+ # split heads before split qkv
436
+ self.attention = QKVAttentionLegacy(self.num_heads)
437
+
438
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
439
+
440
+ def forward(self, x):
441
+ return torch_checkpoint(self._forward, (x, ), self.use_checkpoint)
442
+
443
+ def _forward(self, x):
444
+ b, c, *spatial = x.shape
445
+ x = x.reshape(b, c, -1)
446
+ qkv = self.qkv(self.norm(x))
447
+ h = self.attention(qkv)
448
+ h = self.proj_out(h)
449
+ return (x + h).reshape(b, c, *spatial)
450
+
451
+
452
+ def count_flops_attn(model, _x, y):
453
+ """
454
+ A counter for the `thop` package to count the operations in an
455
+ attention operation.
456
+ Meant to be used like:
457
+ macs, params = thop.profile(
458
+ model,
459
+ inputs=(inputs, timestamps),
460
+ custom_ops={QKVAttention: QKVAttention.count_flops},
461
+ )
462
+ """
463
+ b, c, *spatial = y[0].shape
464
+ num_spatial = int(np.prod(spatial))
465
+ # We perform two matmuls with the same number of ops.
466
+ # The first computes the weight matrix, the second computes
467
+ # the combination of the value vectors.
468
+ matmul_ops = 2 * b * (num_spatial**2) * c
469
+ model.total_ops += th.DoubleTensor([matmul_ops])
470
+
471
+
472
+ class QKVAttentionLegacy(nn.Module):
473
+ """
474
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
475
+ """
476
+ def __init__(self, n_heads):
477
+ super().__init__()
478
+ self.n_heads = n_heads
479
+
480
+ def forward(self, qkv):
481
+ """
482
+ Apply QKV attention.
483
+
484
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
485
+ :return: an [N x (H * C) x T] tensor after attention.
486
+ """
487
+ bs, width, length = qkv.shape
488
+ assert width % (3 * self.n_heads) == 0
489
+ ch = width // (3 * self.n_heads)
490
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
491
+ dim=1)
492
+ scale = 1 / math.sqrt(math.sqrt(ch))
493
+ weight = th.einsum(
494
+ "bct,bcs->bts", q * scale,
495
+ k * scale) # More stable with f16 than dividing afterwards
496
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
497
+ a = th.einsum("bts,bcs->bct", weight, v)
498
+ return a.reshape(bs, -1, length)
499
+
500
+ @staticmethod
501
+ def count_flops(model, _x, y):
502
+ return count_flops_attn(model, _x, y)
503
+
504
+
505
+ class QKVAttention(nn.Module):
506
+ """
507
+ A module which performs QKV attention and splits in a different order.
508
+ """
509
+ def __init__(self, n_heads):
510
+ super().__init__()
511
+ self.n_heads = n_heads
512
+
513
+ def forward(self, qkv):
514
+ """
515
+ Apply QKV attention.
516
+
517
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
518
+ :return: an [N x (H * C) x T] tensor after attention.
519
+ """
520
+ bs, width, length = qkv.shape
521
+ assert width % (3 * self.n_heads) == 0
522
+ ch = width // (3 * self.n_heads)
523
+ q, k, v = qkv.chunk(3, dim=1)
524
+ scale = 1 / math.sqrt(math.sqrt(ch))
525
+ weight = th.einsum(
526
+ "bct,bcs->bts",
527
+ (q * scale).view(bs * self.n_heads, ch, length),
528
+ (k * scale).view(bs * self.n_heads, ch, length),
529
+ ) # More stable with f16 than dividing afterwards
530
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
531
+ a = th.einsum("bts,bcs->bct", weight,
532
+ v.reshape(bs * self.n_heads, ch, length))
533
+ return a.reshape(bs, -1, length)
534
+
535
+ @staticmethod
536
+ def count_flops(model, _x, y):
537
+ return count_flops_attn(model, _x, y)
538
+
539
+
540
+ class AttentionPool2d(nn.Module):
541
+ """
542
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
543
+ """
544
+ def __init__(
545
+ self,
546
+ spacial_dim: int,
547
+ embed_dim: int,
548
+ num_heads_channels: int,
549
+ output_dim: int = None,
550
+ ):
551
+ super().__init__()
552
+ self.positional_embedding = nn.Parameter(
553
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
554
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
555
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
556
+ self.num_heads = embed_dim // num_heads_channels
557
+ self.attention = QKVAttention(self.num_heads)
558
+
559
+ def forward(self, x):
560
+ b, c, *_spatial = x.shape
561
+ x = x.reshape(b, c, -1) # NC(HW)
562
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
563
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
564
+ x = self.qkv_proj(x)
565
+ x = self.attention(x)
566
+ x = self.c_proj(x)
567
+ return x[:, :, 0]
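The conditioning path in `apply_conditions` above is the usual scale-and-shift modulation applied right after normalization (and before the activation). A standalone sketch of just that step, with made-up shapes:

```python
# minimal sketch of the scale/shift step inside apply_conditions (shapes are illustrative)
import torch

h = torch.randn(2, 8, 16, 16)                 # feature maps (n, c, h, w)
emb = torch.randn(2, 16, 1, 1)                # time embedding projected to 2*c channels
scale, shift = torch.chunk(emb, 2, dim=1)     # (n, c, 1, 1) each
scale_bias = 1                                 # same default as in the ResBlock
h = h * (scale_bias + scale) + shift           # modulation after norm, before SiLU
print(h.shape)                                 # torch.Size([2, 8, 16, 16])
```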
model/latentnet.py ADDED
@@ -0,0 +1,193 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from enum import Enum
4
+ from typing import NamedTuple, Tuple
5
+
6
+ import torch
7
+ from choices import *
8
+ from config_base import BaseConfig
9
+ from torch import nn
10
+ from torch.nn import init
11
+
12
+ from .blocks import *
13
+ from .nn import timestep_embedding
14
+ from .unet import *
15
+
16
+
17
+ class LatentNetType(Enum):
18
+ none = 'none'
19
+ # injecting inputs into the hidden layers
20
+ skip = 'skip'
21
+
22
+
23
+ class LatentNetReturn(NamedTuple):
24
+ pred: torch.Tensor = None
25
+
26
+
27
+ @dataclass
28
+ class MLPSkipNetConfig(BaseConfig):
29
+ """
30
+ default MLP for the latent DPM in the paper!
31
+ """
32
+ num_channels: int
33
+ skip_layers: Tuple[int]
34
+ num_hid_channels: int
35
+ num_layers: int
36
+ num_time_emb_channels: int = 64
37
+ activation: Activation = Activation.silu
38
+ use_norm: bool = True
39
+ condition_bias: float = 1
40
+ dropout: float = 0
41
+ last_act: Activation = Activation.none
42
+ num_time_layers: int = 2
43
+ time_last_act: bool = False
44
+
45
+ def make_model(self):
46
+ return MLPSkipNet(self)
47
+
48
+
49
+ class MLPSkipNet(nn.Module):
50
+ """
51
+ concat x to hidden layers
52
+
53
+ default MLP for the latent DPM in the paper!
54
+ """
55
+ def __init__(self, conf: MLPSkipNetConfig):
56
+ super().__init__()
57
+ self.conf = conf
58
+
59
+ layers = []
60
+ for i in range(conf.num_time_layers):
61
+ if i == 0:
62
+ a = conf.num_time_emb_channels
63
+ b = conf.num_channels
64
+ else:
65
+ a = conf.num_channels
66
+ b = conf.num_channels
67
+ layers.append(nn.Linear(a, b))
68
+ if i < conf.num_time_layers - 1 or conf.time_last_act:
69
+ layers.append(conf.activation.get_act())
70
+ self.time_embed = nn.Sequential(*layers)
71
+
72
+ self.layers = nn.ModuleList([])
73
+ for i in range(conf.num_layers):
74
+ if i == 0:
75
+ act = conf.activation
76
+ norm = conf.use_norm
77
+ cond = True
78
+ a, b = conf.num_channels, conf.num_hid_channels
79
+ dropout = conf.dropout
80
+ elif i == conf.num_layers - 1:
81
+ act = Activation.none
82
+ norm = False
83
+ cond = False
84
+ a, b = conf.num_hid_channels, conf.num_channels
85
+ dropout = 0
86
+ else:
87
+ act = conf.activation
88
+ norm = conf.use_norm
89
+ cond = True
90
+ a, b = conf.num_hid_channels, conf.num_hid_channels
91
+ dropout = conf.dropout
92
+
93
+ if i in conf.skip_layers:
94
+ a += conf.num_channels
95
+
96
+ self.layers.append(
97
+ MLPLNAct(
98
+ a,
99
+ b,
100
+ norm=norm,
101
+ activation=act,
102
+ cond_channels=conf.num_channels,
103
+ use_cond=cond,
104
+ condition_bias=conf.condition_bias,
105
+ dropout=dropout,
106
+ ))
107
+ self.last_act = conf.last_act.get_act()
108
+
109
+ def forward(self, x, t, **kwargs):
110
+ t = timestep_embedding(t, self.conf.num_time_emb_channels)
111
+ cond = self.time_embed(t)
112
+ h = x
113
+ for i in range(len(self.layers)):
114
+ if i in self.conf.skip_layers:
115
+ # injecting input into the hidden layers
116
+ h = torch.cat([h, x], dim=1)
117
+ h = self.layers[i].forward(x=h, cond=cond)
118
+ h = self.last_act(h)
119
+ return LatentNetReturn(h)
120
+
121
+
122
+ class MLPLNAct(nn.Module):
123
+ def __init__(
124
+ self,
125
+ in_channels: int,
126
+ out_channels: int,
127
+ norm: bool,
128
+ use_cond: bool,
129
+ activation: Activation,
130
+ cond_channels: int,
131
+ condition_bias: float = 0,
132
+ dropout: float = 0,
133
+ ):
134
+ super().__init__()
135
+ self.activation = activation
136
+ self.condition_bias = condition_bias
137
+ self.use_cond = use_cond
138
+
139
+ self.linear = nn.Linear(in_channels, out_channels)
140
+ self.act = activation.get_act()
141
+ if self.use_cond:
142
+ self.linear_emb = nn.Linear(cond_channels, out_channels)
143
+ self.cond_layers = nn.Sequential(self.act, self.linear_emb)
144
+ if norm:
145
+ self.norm = nn.LayerNorm(out_channels)
146
+ else:
147
+ self.norm = nn.Identity()
148
+
149
+ if dropout > 0:
150
+ self.dropout = nn.Dropout(p=dropout)
151
+ else:
152
+ self.dropout = nn.Identity()
153
+
154
+ self.init_weights()
155
+
156
+ def init_weights(self):
157
+ for module in self.modules():
158
+ if isinstance(module, nn.Linear):
159
+ if self.activation == Activation.relu:
160
+ init.kaiming_normal_(module.weight,
161
+ a=0,
162
+ nonlinearity='relu')
163
+ elif self.activation == Activation.lrelu:
164
+ init.kaiming_normal_(module.weight,
165
+ a=0.2,
166
+ nonlinearity='leaky_relu')
167
+ elif self.activation == Activation.silu:
168
+ init.kaiming_normal_(module.weight,
169
+ a=0,
170
+ nonlinearity='relu')
171
+ else:
172
+ # leave it as default
173
+ pass
174
+
175
+ def forward(self, x, cond=None):
176
+ x = self.linear(x)
177
+ if self.use_cond:
178
+ # (n, c) or (n, c * 2)
179
+ cond = self.cond_layers(cond)
180
+ cond = (cond, None)
181
+
182
+ # scale shift first
183
+ x = x * (self.condition_bias + cond[0])
184
+ if cond[1] is not None:
185
+ x = x + cond[1]
186
+ # then norm
187
+ x = self.norm(x)
188
+ else:
189
+ # no condition
190
+ x = self.norm(x)
191
+ x = self.act(x)
192
+ x = self.dropout(x)
193
+ return x
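A hedged instantiation sketch for the latent-DPM MLP above; the channel sizes are illustrative, not the paper's exact settings.

```python
# sketch: building and calling MLPSkipNet via its config (sizes are illustrative)
import torch
from model.latentnet import MLPSkipNetConfig

conf = MLPSkipNetConfig(
    num_channels=512,          # dimensionality of the semantic latent z
    skip_layers=(1, 2),        # hidden layers that re-concatenate the input
    num_hid_channels=1024,
    num_layers=3,
)
net = conf.make_model()
z = torch.randn(8, 512)
t = torch.randint(0, 1000, (8, ))
print(net(z, t).pred.shape)    # torch.Size([8, 512])
```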
model/nn.py ADDED
@@ -0,0 +1,137 @@
+ """
+ Various utilities for neural networks.
+ """
+
+ from enum import Enum
+ import math
+ from typing import Optional
+
+ import torch as th
+ import torch.nn as nn
+ import torch.utils.checkpoint
+
+ import torch.nn.functional as F
+
+
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+ class SiLU(nn.Module):
+     # @th.jit.script
+     def forward(self, x):
+         return x * th.sigmoid(x)
+
+
+ class GroupNorm32(nn.GroupNorm):
+     def forward(self, x):
+         return super().forward(x.float()).type(x.dtype)
+
+
+ def conv_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D convolution module.
+     """
+     if dims == 1:
+         return nn.Conv1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.Conv2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.Conv3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def linear(*args, **kwargs):
+     """
+     Create a linear module.
+     """
+     return nn.Linear(*args, **kwargs)
+
+
+ def avg_pool_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D average pooling module.
+     """
+     if dims == 1:
+         return nn.AvgPool1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.AvgPool2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.AvgPool3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def update_ema(target_params, source_params, rate=0.99):
+     """
+     Update target parameters to be closer to those of source parameters using
+     an exponential moving average.
+
+     :param target_params: the target parameter sequence.
+     :param source_params: the source parameter sequence.
+     :param rate: the EMA rate (closer to 1 means slower).
+     """
+     for targ, src in zip(target_params, source_params):
+         targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
+ def scale_module(module, scale):
+     """
+     Scale the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().mul_(scale)
+     return module
+
+
+ def mean_flat(tensor):
+     """
+     Take the mean over all non-batch dimensions.
+     """
+     return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+ def normalization(channels):
+     """
+     Make a standard normalization layer.
+
+     :param channels: number of input channels.
+     :return: an nn.Module for normalization.
+     """
+     return GroupNorm32(min(32, channels), channels)
+
+
+ def timestep_embedding(timesteps, dim, max_period=10000):
+     """
+     Create sinusoidal timestep embeddings.
+
+     :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an [N x dim] Tensor of positional embeddings.
+     """
+     half = dim // 2
+     freqs = th.exp(-math.log(max_period) *
+                    th.arange(start=0, end=half, dtype=th.float32) /
+                    half).to(device=timesteps.device)
+     args = timesteps[:, None].float() * freqs[None]
+     embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+     if dim % 2:
+         embedding = th.cat(
+             [embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+     return embedding
+
+
+ def torch_checkpoint(func, args, flag, preserve_rng_state=False):
+     # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
+     if flag:
+         return torch.utils.checkpoint.checkpoint(
+             func, *args, preserve_rng_state=preserve_rng_state)
+     else:
+         return func(*args)
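
Of the helpers above, `timestep_embedding` is the one with non-obvious output conventions, so a minimal usage sketch may help. The batch size and embedding width below are arbitrary illustration values, and the snippet assumes the function is importable from `model/nn.py` as added above.

```python
import torch as th

from model.nn import timestep_embedding  # as defined in the diff above

# three (possibly fractional) timesteps, one per batch element
t = th.tensor([0.0, 10.0, 999.0])
emb = timestep_embedding(t, dim=512)  # first half cosines, second half sines
assert emb.shape == (3, 512)
```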
model/unet.py ADDED
@@ -0,0 +1,552 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from numbers import Number
4
+ from typing import NamedTuple, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ from choices import *
11
+ from config_base import BaseConfig
12
+ from .blocks import *
13
+
14
+ from .nn import (conv_nd, linear, normalization, timestep_embedding,
15
+ torch_checkpoint, zero_module)
16
+
17
+
18
+ @dataclass
19
+ class BeatGANsUNetConfig(BaseConfig):
20
+ image_size: int = 64
21
+ in_channels: int = 3
22
+ # base channels, will be multiplied
23
+ model_channels: int = 64
24
+ # output of the unet
25
+ # suggest: 3
26
+ # you only need 6 if you also model the variance of the noise prediction (usually we use an analytical variance hence 3)
27
+ out_channels: int = 3
28
+ # how many repeating resblocks per resolution
29
+ # the decoding side would have "one more" resblock
30
+ # default: 2
31
+ num_res_blocks: int = 2
32
+ # you can also set the number of resblocks specifically for the input blocks
33
+ # default: None = above
34
+ num_input_res_blocks: int = None
35
+ # number of time embed channels and style channels
36
+ embed_channels: int = 512
37
+ # at what resolutions you want to do self-attention of the feature maps
38
+ # attentions generally improve performance
39
+ # default: [16]
40
+ # beatgans: [32, 16, 8]
41
+ attention_resolutions: Tuple[int] = (16, )
42
+ # number of time embed channels
43
+ time_embed_channels: int = None
44
+ # dropout applies to the resblocks (on feature maps)
45
+ dropout: float = 0.1
46
+ channel_mult: Tuple[int] = (1, 2, 4, 8)
47
+ input_channel_mult: Tuple[int] = None
48
+ conv_resample: bool = True
49
+ # always 2 = 2d conv
50
+ dims: int = 2
51
+ # don't use this, legacy from BeatGANs
52
+ num_classes: int = None
53
+ use_checkpoint: bool = False
54
+ # number of attention heads
55
+ num_heads: int = 1
56
+ # or specify the number of channels per attention head
57
+ num_head_channels: int = -1
58
+ # what's this?
59
+ num_heads_upsample: int = -1
60
+ # use resblock for upscale/downscale blocks (expensive)
61
+ # default: True (BeatGANs)
62
+ resblock_updown: bool = True
63
+ # never tried
64
+ use_new_attention_order: bool = False
65
+ resnet_two_cond: bool = False
66
+ resnet_cond_channels: int = None
67
+ # init the decoding conv layers with zero weights, this speeds up training
68
+ # default: True (BeatGANs)
69
+ resnet_use_zero_module: bool = True
70
+ # gradient checkpoint the attention operation
71
+ attn_checkpoint: bool = False
72
+
73
+ def make_model(self):
74
+ return BeatGANsUNetModel(self)
75
+
76
+
77
+ class BeatGANsUNetModel(nn.Module):
78
+ def __init__(self, conf: BeatGANsUNetConfig):
79
+ super().__init__()
80
+ self.conf = conf
81
+
82
+ if conf.num_heads_upsample == -1:
83
+ self.num_heads_upsample = conf.num_heads
84
+
85
+ self.dtype = th.float32
86
+
87
+ self.time_emb_channels = conf.time_embed_channels or conf.model_channels
88
+ self.time_embed = nn.Sequential(
89
+ linear(self.time_emb_channels, conf.embed_channels),
90
+ nn.SiLU(),
91
+ linear(conf.embed_channels, conf.embed_channels),
92
+ )
93
+
94
+ if conf.num_classes is not None:
95
+ self.label_emb = nn.Embedding(conf.num_classes,
96
+ conf.embed_channels)
97
+
98
+ ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
99
+ self.input_blocks = nn.ModuleList([
100
+ TimestepEmbedSequential(
101
+ conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
102
+ ])
103
+
104
+ kwargs = dict(
105
+ use_condition=True,
106
+ two_cond=conf.resnet_two_cond,
107
+ use_zero_module=conf.resnet_use_zero_module,
108
+ # style channels for the resnet block
109
+ cond_emb_channels=conf.resnet_cond_channels,
110
+ )
111
+
112
+ self._feature_size = ch
113
+
114
+ # input_block_chans = [ch]
115
+ input_block_chans = [[] for _ in range(len(conf.channel_mult))]
116
+ input_block_chans[0].append(ch)
117
+
118
+ # number of blocks at each resolution
119
+ self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
120
+ self.input_num_blocks[0] = 1
121
+ self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
122
+
123
+ ds = 1
124
+ resolution = conf.image_size
125
+ for level, mult in enumerate(conf.input_channel_mult
126
+ or conf.channel_mult):
127
+ for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
128
+ layers = [
129
+ ResBlockConfig(
130
+ ch,
131
+ conf.embed_channels,
132
+ conf.dropout,
133
+ out_channels=int(mult * conf.model_channels),
134
+ dims=conf.dims,
135
+ use_checkpoint=conf.use_checkpoint,
136
+ **kwargs,
137
+ ).make_model()
138
+ ]
139
+ ch = int(mult * conf.model_channels)
140
+ if resolution in conf.attention_resolutions:
141
+ layers.append(
142
+ AttentionBlock(
143
+ ch,
144
+ use_checkpoint=conf.use_checkpoint
145
+ or conf.attn_checkpoint,
146
+ num_heads=conf.num_heads,
147
+ num_head_channels=conf.num_head_channels,
148
+ use_new_attention_order=conf.
149
+ use_new_attention_order,
150
+ ))
151
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
152
+ self._feature_size += ch
153
+ # input_block_chans.append(ch)
154
+ input_block_chans[level].append(ch)
155
+ self.input_num_blocks[level] += 1
156
+ # print(input_block_chans)
157
+ if level != len(conf.channel_mult) - 1:
158
+ resolution //= 2
159
+ out_ch = ch
160
+ self.input_blocks.append(
161
+ TimestepEmbedSequential(
162
+ ResBlockConfig(
163
+ ch,
164
+ conf.embed_channels,
165
+ conf.dropout,
166
+ out_channels=out_ch,
167
+ dims=conf.dims,
168
+ use_checkpoint=conf.use_checkpoint,
169
+ down=True,
170
+ **kwargs,
171
+ ).make_model() if conf.
172
+ resblock_updown else Downsample(ch,
173
+ conf.conv_resample,
174
+ dims=conf.dims,
175
+ out_channels=out_ch)))
176
+ ch = out_ch
177
+ # input_block_chans.append(ch)
178
+ input_block_chans[level + 1].append(ch)
179
+ self.input_num_blocks[level + 1] += 1
180
+ ds *= 2
181
+ self._feature_size += ch
182
+
183
+ self.middle_block = TimestepEmbedSequential(
184
+ ResBlockConfig(
185
+ ch,
186
+ conf.embed_channels,
187
+ conf.dropout,
188
+ dims=conf.dims,
189
+ use_checkpoint=conf.use_checkpoint,
190
+ **kwargs,
191
+ ).make_model(),
192
+ AttentionBlock(
193
+ ch,
194
+ use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
195
+ num_heads=conf.num_heads,
196
+ num_head_channels=conf.num_head_channels,
197
+ use_new_attention_order=conf.use_new_attention_order,
198
+ ),
199
+ ResBlockConfig(
200
+ ch,
201
+ conf.embed_channels,
202
+ conf.dropout,
203
+ dims=conf.dims,
204
+ use_checkpoint=conf.use_checkpoint,
205
+ **kwargs,
206
+ ).make_model(),
207
+ )
208
+ self._feature_size += ch
209
+
210
+ self.output_blocks = nn.ModuleList([])
211
+ for level, mult in list(enumerate(conf.channel_mult))[::-1]:
212
+ for i in range(conf.num_res_blocks + 1):
213
+ # print(input_block_chans)
214
+ # ich = input_block_chans.pop()
215
+ try:
216
+ ich = input_block_chans[level].pop()
217
+ except IndexError:
218
+ # this happens only when num_res_block > num_enc_res_block
219
+ # we will not have enough lateral (skip) connections for all decoder blocks
220
+ ich = 0
221
+ # print('pop:', ich)
222
+ layers = [
223
+ ResBlockConfig(
224
+ # only direct channels when gated
225
+ channels=ch + ich,
226
+ emb_channels=conf.embed_channels,
227
+ dropout=conf.dropout,
228
+ out_channels=int(conf.model_channels * mult),
229
+ dims=conf.dims,
230
+ use_checkpoint=conf.use_checkpoint,
231
+ # lateral channels are described here when gated
232
+ has_lateral=True if ich > 0 else False,
233
+ lateral_channels=None,
234
+ **kwargs,
235
+ ).make_model()
236
+ ]
237
+ ch = int(conf.model_channels * mult)
238
+ if resolution in conf.attention_resolutions:
239
+ layers.append(
240
+ AttentionBlock(
241
+ ch,
242
+ use_checkpoint=conf.use_checkpoint
243
+ or conf.attn_checkpoint,
244
+ num_heads=self.num_heads_upsample,
245
+ num_head_channels=conf.num_head_channels,
246
+ use_new_attention_order=conf.
247
+ use_new_attention_order,
248
+ ))
249
+ if level and i == conf.num_res_blocks:
250
+ resolution *= 2
251
+ out_ch = ch
252
+ layers.append(
253
+ ResBlockConfig(
254
+ ch,
255
+ conf.embed_channels,
256
+ conf.dropout,
257
+ out_channels=out_ch,
258
+ dims=conf.dims,
259
+ use_checkpoint=conf.use_checkpoint,
260
+ up=True,
261
+ **kwargs,
262
+ ).make_model() if (
263
+ conf.resblock_updown
264
+ ) else Upsample(ch,
265
+ conf.conv_resample,
266
+ dims=conf.dims,
267
+ out_channels=out_ch))
268
+ ds //= 2
269
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
270
+ self.output_num_blocks[level] += 1
271
+ self._feature_size += ch
272
+
273
+ # print(input_block_chans)
274
+ # print('inputs:', self.input_num_blocks)
275
+ # print('outputs:', self.output_num_blocks)
276
+
277
+ if conf.resnet_use_zero_module:
278
+ self.out = nn.Sequential(
279
+ normalization(ch),
280
+ nn.SiLU(),
281
+ zero_module(
282
+ conv_nd(conf.dims,
283
+ input_ch,
284
+ conf.out_channels,
285
+ 3,
286
+ padding=1)),
287
+ )
288
+ else:
289
+ self.out = nn.Sequential(
290
+ normalization(ch),
291
+ nn.SiLU(),
292
+ conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
293
+ )
294
+
295
+ def forward(self, x, t, y=None, **kwargs):
296
+ """
297
+ Apply the model to an input batch.
298
+
299
+ :param x: an [N x C x ...] Tensor of inputs.
300
+ :param t: a 1-D batch of timesteps.
301
+ :param y: an [N] Tensor of labels, if class-conditional.
302
+ :return: an [N x C x ...] Tensor of outputs.
303
+ """
304
+ assert (y is not None) == (
305
+ self.conf.num_classes is not None
306
+ ), "must specify y if and only if the model is class-conditional"
307
+
308
+ # hs = []
309
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
310
+ emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
311
+
312
+ if self.conf.num_classes is not None:
313
+ raise NotImplementedError()
314
+ # assert y.shape == (x.shape[0], )
315
+ # emb = emb + self.label_emb(y)
316
+
317
+ # new code supports input_num_blocks != output_num_blocks
318
+ h = x.type(self.dtype)
319
+ k = 0
320
+ for i in range(len(self.input_num_blocks)):
321
+ for j in range(self.input_num_blocks[i]):
322
+ h = self.input_blocks[k](h, emb=emb)
323
+ # print(i, j, h.shape)
324
+ hs[i].append(h)
325
+ k += 1
326
+ assert k == len(self.input_blocks)
327
+
328
+ h = self.middle_block(h, emb=emb)
329
+ k = 0
330
+ for i in range(len(self.output_num_blocks)):
331
+ for j in range(self.output_num_blocks[i]):
332
+ # take the lateral connection from the same layer (in reverse)
333
+ # until there is no more, use None
334
+ try:
335
+ lateral = hs[-i - 1].pop()
336
+ # print(i, j, lateral.shape)
337
+ except IndexError:
338
+ lateral = None
339
+ # print(i, j, lateral)
340
+ h = self.output_blocks[k](h, emb=emb, lateral=lateral)
341
+ k += 1
342
+
343
+ h = h.type(x.dtype)
344
+ pred = self.out(h)
345
+ return Return(pred=pred)
346
+
347
+
348
+ class Return(NamedTuple):
349
+ pred: th.Tensor
350
+
351
+
352
+ @dataclass
353
+ class BeatGANsEncoderConfig(BaseConfig):
354
+ image_size: int
355
+ in_channels: int
356
+ model_channels: int
357
+ out_hid_channels: int
358
+ out_channels: int
359
+ num_res_blocks: int
360
+ attention_resolutions: Tuple[int]
361
+ dropout: float = 0
362
+ channel_mult: Tuple[int] = (1, 2, 4, 8)
363
+ use_time_condition: bool = True
364
+ conv_resample: bool = True
365
+ dims: int = 2
366
+ use_checkpoint: bool = False
367
+ num_heads: int = 1
368
+ num_head_channels: int = -1
369
+ resblock_updown: bool = False
370
+ use_new_attention_order: bool = False
371
+ pool: str = 'adaptivenonzero'
372
+
373
+ def make_model(self):
374
+ return BeatGANsEncoderModel(self)
375
+
376
+
377
+ class BeatGANsEncoderModel(nn.Module):
378
+ """
379
+ The half UNet model with attention and timestep embedding.
380
+
381
+ For usage, see UNet.
382
+ """
383
+ def __init__(self, conf: BeatGANsEncoderConfig):
384
+ super().__init__()
385
+ self.conf = conf
386
+ self.dtype = th.float32
387
+
388
+ if conf.use_time_condition:
389
+ time_embed_dim = conf.model_channels * 4
390
+ self.time_embed = nn.Sequential(
391
+ linear(conf.model_channels, time_embed_dim),
392
+ nn.SiLU(),
393
+ linear(time_embed_dim, time_embed_dim),
394
+ )
395
+ else:
396
+ time_embed_dim = None
397
+
398
+ ch = int(conf.channel_mult[0] * conf.model_channels)
399
+ self.input_blocks = nn.ModuleList([
400
+ TimestepEmbedSequential(
401
+ conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
402
+ ])
403
+ self._feature_size = ch
404
+ input_block_chans = [ch]
405
+ ds = 1
406
+ resolution = conf.image_size
407
+ for level, mult in enumerate(conf.channel_mult):
408
+ for _ in range(conf.num_res_blocks):
409
+ layers = [
410
+ ResBlockConfig(
411
+ ch,
412
+ time_embed_dim,
413
+ conf.dropout,
414
+ out_channels=int(mult * conf.model_channels),
415
+ dims=conf.dims,
416
+ use_condition=conf.use_time_condition,
417
+ use_checkpoint=conf.use_checkpoint,
418
+ ).make_model()
419
+ ]
420
+ ch = int(mult * conf.model_channels)
421
+ if resolution in conf.attention_resolutions:
422
+ layers.append(
423
+ AttentionBlock(
424
+ ch,
425
+ use_checkpoint=conf.use_checkpoint,
426
+ num_heads=conf.num_heads,
427
+ num_head_channels=conf.num_head_channels,
428
+ use_new_attention_order=conf.
429
+ use_new_attention_order,
430
+ ))
431
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
432
+ self._feature_size += ch
433
+ input_block_chans.append(ch)
434
+ if level != len(conf.channel_mult) - 1:
435
+ resolution //= 2
436
+ out_ch = ch
437
+ self.input_blocks.append(
438
+ TimestepEmbedSequential(
439
+ ResBlockConfig(
440
+ ch,
441
+ time_embed_dim,
442
+ conf.dropout,
443
+ out_channels=out_ch,
444
+ dims=conf.dims,
445
+ use_condition=conf.use_time_condition,
446
+ use_checkpoint=conf.use_checkpoint,
447
+ down=True,
448
+ ).make_model() if (
449
+ conf.resblock_updown
450
+ ) else Downsample(ch,
451
+ conf.conv_resample,
452
+ dims=conf.dims,
453
+ out_channels=out_ch)))
454
+ ch = out_ch
455
+ input_block_chans.append(ch)
456
+ ds *= 2
457
+ self._feature_size += ch
458
+
459
+ self.middle_block = TimestepEmbedSequential(
460
+ ResBlockConfig(
461
+ ch,
462
+ time_embed_dim,
463
+ conf.dropout,
464
+ dims=conf.dims,
465
+ use_condition=conf.use_time_condition,
466
+ use_checkpoint=conf.use_checkpoint,
467
+ ).make_model(),
468
+ AttentionBlock(
469
+ ch,
470
+ use_checkpoint=conf.use_checkpoint,
471
+ num_heads=conf.num_heads,
472
+ num_head_channels=conf.num_head_channels,
473
+ use_new_attention_order=conf.use_new_attention_order,
474
+ ),
475
+ ResBlockConfig(
476
+ ch,
477
+ time_embed_dim,
478
+ conf.dropout,
479
+ dims=conf.dims,
480
+ use_condition=conf.use_time_condition,
481
+ use_checkpoint=conf.use_checkpoint,
482
+ ).make_model(),
483
+ )
484
+ self._feature_size += ch
485
+ if conf.pool == "adaptivenonzero":
486
+ self.out = nn.Sequential(
487
+ normalization(ch),
488
+ nn.SiLU(),
489
+ nn.AdaptiveAvgPool2d((1, 1)),
490
+ conv_nd(conf.dims, ch, conf.out_channels, 1),
491
+ nn.Flatten(),
492
+ )
493
+ else:
494
+ raise NotImplementedError(f"Unexpected {conf.pool} pooling")
495
+
496
+ def forward(self, x, t=None, return_2d_feature=False):
497
+ """
498
+ Apply the model to an input batch.
499
+
500
+ :param x: an [N x C x ...] Tensor of inputs.
501
+ :param t: a 1-D batch of timesteps.
502
+ :return: an [N x K] Tensor of outputs.
503
+ """
504
+ if self.conf.use_time_condition:
505
+ emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
506
+ else:
507
+ emb = None
508
+
509
+ results = []
510
+ h = x.type(self.dtype)
511
+ for module in self.input_blocks:
512
+ h = module(h, emb=emb)
513
+ if self.conf.pool.startswith("spatial"):
514
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
515
+ h = self.middle_block(h, emb=emb)
516
+ if self.conf.pool.startswith("spatial"):
517
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
518
+ h = th.cat(results, axis=-1)
519
+ else:
520
+ h = h.type(x.dtype)
521
+
522
+ h_2d = h
523
+ h = self.out(h)
524
+
525
+ if return_2d_feature:
526
+ return h, h_2d
527
+ else:
528
+ return h
529
+
530
+ def forward_flatten(self, x):
531
+ """
532
+ transform the last 2d feature into a flattened vector
533
+ """
534
+ h = self.out(x)
535
+ return h
536
+
537
+
538
+ class SuperResModel(BeatGANsUNetModel):
539
+ """
540
+ A UNetModel that performs super-resolution.
541
+
542
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
543
+ """
544
+ def __init__(self, image_size, in_channels, *args, **kwargs):
545
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
546
+
547
+ def forward(self, x, timesteps, low_res=None, **kwargs):
548
+ _, _, new_height, new_width = x.shape
549
+ upsampled = F.interpolate(low_res, (new_height, new_width),
550
+ mode="bilinear")
551
+ x = th.cat([x, upsampled], dim=1)
552
+ return super().forward(x, timesteps, **kwargs)
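
To make the `BeatGANsUNetConfig` surface concrete, a small smoke-test sketch follows. The sizes are hypothetical illustration values rather than settings used in the paper, and it assumes the companion `model/blocks.py` from this same upload is importable alongside `model/unet.py`.

```python
import torch as th

from model.unet import BeatGANsUNetConfig  # as added above

conf = BeatGANsUNetConfig(image_size=32,
                          model_channels=32,
                          channel_mult=(1, 2),
                          attention_resolutions=(16, ))
model = conf.make_model()  # BeatGANsUNetModel

x = th.randn(2, 3, 32, 32)      # batch of noisy images
t = th.randint(0, 1000, (2, ))  # one diffusion timestep per image
out = model(x, t)               # Return(pred=...)
assert out.pred.shape == x.shape
```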
model/unet_autoenc.py ADDED
@@ -0,0 +1,283 @@
1
+ from enum import Enum
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch.nn.functional import silu
6
+
7
+ from .latentnet import *
8
+ from .unet import *
9
+ from choices import *
10
+
11
+
12
+ @dataclass
13
+ class BeatGANsAutoencConfig(BeatGANsUNetConfig):
14
+ # number of style channels
15
+ enc_out_channels: int = 512
16
+ enc_attn_resolutions: Tuple[int] = None
17
+ enc_pool: str = 'depthconv'
18
+ enc_num_res_block: int = 2
19
+ enc_channel_mult: Tuple[int] = None
20
+ enc_grad_checkpoint: bool = False
21
+ latent_net_conf: MLPSkipNetConfig = None
22
+
23
+ def make_model(self):
24
+ return BeatGANsAutoencModel(self)
25
+
26
+
27
+ class BeatGANsAutoencModel(BeatGANsUNetModel):
28
+ def __init__(self, conf: BeatGANsAutoencConfig):
29
+ super().__init__(conf)
30
+ self.conf = conf
31
+
32
+ # having only time, cond
33
+ self.time_embed = TimeStyleSeperateEmbed(
34
+ time_channels=conf.model_channels,
35
+ time_out_channels=conf.embed_channels,
36
+ )
37
+
38
+ self.encoder = BeatGANsEncoderConfig(
39
+ image_size=conf.image_size,
40
+ in_channels=conf.in_channels,
41
+ model_channels=conf.model_channels,
42
+ out_hid_channels=conf.enc_out_channels,
43
+ out_channels=conf.enc_out_channels,
44
+ num_res_blocks=conf.enc_num_res_block,
45
+ attention_resolutions=(conf.enc_attn_resolutions
46
+ or conf.attention_resolutions),
47
+ dropout=conf.dropout,
48
+ channel_mult=conf.enc_channel_mult or conf.channel_mult,
49
+ use_time_condition=False,
50
+ conv_resample=conf.conv_resample,
51
+ dims=conf.dims,
52
+ use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
53
+ num_heads=conf.num_heads,
54
+ num_head_channels=conf.num_head_channels,
55
+ resblock_updown=conf.resblock_updown,
56
+ use_new_attention_order=conf.use_new_attention_order,
57
+ pool=conf.enc_pool,
58
+ ).make_model()
59
+
60
+ if conf.latent_net_conf is not None:
61
+ self.latent_net = conf.latent_net_conf.make_model()
62
+
63
+ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
64
+ """
65
+ Reparameterization trick to sample from N(mu, var) using
66
+ N(0,1).
67
+ :param mu: (Tensor) Mean of the latent Gaussian [B x D]
68
+ :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
69
+ :return: (Tensor) [B x D]
70
+ """
71
+ assert self.conf.is_stochastic
72
+ std = torch.exp(0.5 * logvar)
73
+ eps = torch.randn_like(std)
74
+ return eps * std + mu
75
+
76
+ def sample_z(self, n: int, device):
77
+ assert self.conf.is_stochastic
78
+ return torch.randn(n, self.conf.enc_out_channels, device=device)
79
+
80
+ def noise_to_cond(self, noise: Tensor):
81
+ raise NotImplementedError()
82
+ assert self.conf.noise_net_conf is not None
83
+ return self.noise_net.forward(noise)
84
+
85
+ def encode(self, x):
86
+ cond = self.encoder.forward(x)
87
+ return {'cond': cond}
88
+
89
+ @property
90
+ def stylespace_sizes(self):
91
+ modules = list(self.input_blocks.modules()) + list(
92
+ self.middle_block.modules()) + list(self.output_blocks.modules())
93
+ sizes = []
94
+ for module in modules:
95
+ if isinstance(module, ResBlock):
96
+ linear = module.cond_emb_layers[-1]
97
+ sizes.append(linear.weight.shape[0])
98
+ return sizes
99
+
100
+ def encode_stylespace(self, x, return_vector: bool = True):
101
+ """
102
+ encode to style space
103
+ """
104
+ modules = list(self.input_blocks.modules()) + list(
105
+ self.middle_block.modules()) + list(self.output_blocks.modules())
106
+ # (n, c)
107
+ cond = self.encoder.forward(x)
108
+ S = []
109
+ for module in modules:
110
+ if isinstance(module, ResBlock):
111
+ # (n, c')
112
+ s = module.cond_emb_layers.forward(cond)
113
+ S.append(s)
114
+
115
+ if return_vector:
116
+ # (n, sum_c)
117
+ return torch.cat(S, dim=1)
118
+ else:
119
+ return S
120
+
121
+ def forward(self,
122
+ x,
123
+ t,
124
+ y=None,
125
+ x_start=None,
126
+ cond=None,
127
+ style=None,
128
+ noise=None,
129
+ t_cond=None,
130
+ **kwargs):
131
+ """
132
+ Apply the model to an input batch.
133
+
134
+ Args:
135
+ x_start: the original image to encode
136
+ cond: output of the encoder
137
+ noise: random noise (to predict the cond)
138
+ """
139
+
140
+ if t_cond is None:
141
+ t_cond = t
142
+
143
+ if noise is not None:
144
+ # if the noise is given, we predict the cond from noise
145
+ cond = self.noise_to_cond(noise)
146
+
147
+ if cond is None:
148
+ if x is not None:
149
+ assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
150
+
151
+ tmp = self.encode(x_start)
152
+ cond = tmp['cond']
153
+
154
+ if t is not None:
155
+ _t_emb = timestep_embedding(t, self.conf.model_channels)
156
+ _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
157
+ else:
158
+ # this happens when training only autoenc
159
+ _t_emb = None
160
+ _t_cond_emb = None
161
+
162
+ if self.conf.resnet_two_cond:
163
+ res = self.time_embed.forward(
164
+ time_emb=_t_emb,
165
+ cond=cond,
166
+ time_cond_emb=_t_cond_emb,
167
+ )
168
+ else:
169
+ raise NotImplementedError()
170
+
171
+ if self.conf.resnet_two_cond:
172
+ # two cond: first = time emb, second = cond_emb
173
+ emb = res.time_emb
174
+ cond_emb = res.emb
175
+ else:
176
+ # one cond = combined of both time and cond
177
+ emb = res.emb
178
+ cond_emb = None
179
+
180
+ # override the style if given
181
+ style = style or res.style
182
+
183
+ assert (y is not None) == (
184
+ self.conf.num_classes is not None
185
+ ), "must specify y if and only if the model is class-conditional"
186
+
187
+ if self.conf.num_classes is not None:
188
+ raise NotImplementedError()
189
+ # assert y.shape == (x.shape[0], )
190
+ # emb = emb + self.label_emb(y)
191
+
192
+ # where in the model to supply time conditions
193
+ enc_time_emb = emb
194
+ mid_time_emb = emb
195
+ dec_time_emb = emb
196
+ # where in the model to supply style conditions
197
+ enc_cond_emb = cond_emb
198
+ mid_cond_emb = cond_emb
199
+ dec_cond_emb = cond_emb
200
+
201
+ # hs = []
202
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
203
+
204
+ if x is not None:
205
+ h = x.type(self.dtype)
206
+
207
+ # input blocks
208
+ k = 0
209
+ for i in range(len(self.input_num_blocks)):
210
+ for j in range(self.input_num_blocks[i]):
211
+ h = self.input_blocks[k](h,
212
+ emb=enc_time_emb,
213
+ cond=enc_cond_emb)
214
+
215
+ # print(i, j, h.shape)
216
+ hs[i].append(h)
217
+ k += 1
218
+ assert k == len(self.input_blocks)
219
+
220
+ # middle blocks
221
+ h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
222
+ else:
223
+ # no lateral connections
224
+ # happens when training only the autoencoder
225
+ h = None
226
+ hs = [[] for _ in range(len(self.conf.channel_mult))]
227
+
228
+ # output blocks
229
+ k = 0
230
+ for i in range(len(self.output_num_blocks)):
231
+ for j in range(self.output_num_blocks[i]):
232
+ # take the lateral connection from the same layer (in reverse)
233
+ # until there is no more, use None
234
+ try:
235
+ lateral = hs[-i - 1].pop()
236
+ # print(i, j, lateral.shape)
237
+ except IndexError:
238
+ lateral = None
239
+ # print(i, j, lateral)
240
+
241
+ h = self.output_blocks[k](h,
242
+ emb=dec_time_emb,
243
+ cond=dec_cond_emb,
244
+ lateral=lateral)
245
+ k += 1
246
+
247
+ pred = self.out(h)
248
+ return AutoencReturn(pred=pred, cond=cond)
249
+
250
+
251
+ class AutoencReturn(NamedTuple):
252
+ pred: Tensor
253
+ cond: Tensor = None
254
+
255
+
256
+ class EmbedReturn(NamedTuple):
257
+ # style and time
258
+ emb: Tensor = None
259
+ # time only
260
+ time_emb: Tensor = None
261
+ # style only (but could depend on time)
262
+ style: Tensor = None
263
+
264
+
265
+ class TimeStyleSeperateEmbed(nn.Module):
266
+ # embed only style
267
+ def __init__(self, time_channels, time_out_channels):
268
+ super().__init__()
269
+ self.time_embed = nn.Sequential(
270
+ linear(time_channels, time_out_channels),
271
+ nn.SiLU(),
272
+ linear(time_out_channels, time_out_channels),
273
+ )
274
+ self.style = nn.Identity()
275
+
276
+ def forward(self, time_emb=None, cond=None, **kwargs):
277
+ if time_emb is None:
278
+ # happens with autoenc training mode
279
+ time_emb = None
280
+ else:
281
+ time_emb = self.time_embed(time_emb)
282
+ style = self.style(cond)
283
+ return EmbedReturn(emb=style, time_emb=time_emb, style=style)
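
A hedged sketch of how the autoencoder variant is driven end to end. The toy sizes below are illustration values only (real settings come from templates.py, e.g. ffhq256_autoenc); `resnet_two_cond=True` is required by the forward pass above, and `enc_pool='adaptivenonzero'` is used because it is the only pooling the encoder implements.

```python
import torch as th

from model.unet_autoenc import BeatGANsAutoencConfig  # as added above

conf = BeatGANsAutoencConfig(image_size=32,
                             model_channels=32,
                             channel_mult=(1, 2),
                             attention_resolutions=(16, ),
                             enc_out_channels=512,
                             enc_pool='adaptivenonzero',
                             resnet_two_cond=True)
model = conf.make_model()  # BeatGANsAutoencModel

x0 = th.randn(2, 3, 32, 32)     # clean images to be encoded
xt = th.randn(2, 3, 32, 32)     # noisy images to be denoised
t = th.randint(0, 1000, (2, ))

z_sem = model.encode(x0)['cond']            # [2, 512] semantic latent
out = model(xt, t, x_start=x0, cond=z_sem)  # AutoencReturn(pred=..., cond=...)
```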
predict.py ADDED
@@ -0,0 +1,182 @@
1
+ # pre-download the weights for 256 resolution model to checkpoints/ffhq256_autoenc and checkpoints/ffhq256_autoenc_cls
2
+ # wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
3
+ # bunzip2 shape_predictor_68_face_landmarks.dat.bz2
4
+
5
+ import os
6
+ import torch
7
+ from torchvision.utils import save_image
8
+ import tempfile
9
+ from templates import *
10
+ from templates_cls import *
11
+ from experiment_classifier import ClsModel
12
+ from align import LandmarksDetector, image_align
13
+ from cog import BasePredictor, Path, Input, BaseModel
14
+
15
+
16
+ class ModelOutput(BaseModel):
17
+ image: Path
18
+
19
+
20
+ class Predictor(BasePredictor):
21
+ def setup(self):
22
+ self.aligned_dir = "aligned"
23
+ os.makedirs(self.aligned_dir, exist_ok=True)
24
+ self.device = "cuda:0"
25
+
26
+ # Model Initialization
27
+ model_config = ffhq256_autoenc()
28
+ self.model = LitModel(model_config)
29
+ state = torch.load("checkpoints/ffhq256_autoenc/last.ckpt", map_location="cpu")
30
+ self.model.load_state_dict(state["state_dict"], strict=False)
31
+ self.model.ema_model.eval()
32
+ self.model.ema_model.to(self.device)
33
+
34
+ # Classifier Initialization
35
+ classifier_config = ffhq256_autoenc_cls()
36
+ classifier_config.pretrain = None # a bit faster
37
+ self.classifier = ClsModel(classifier_config)
38
+ state_class = torch.load(
39
+ "checkpoints/ffhq256_autoenc_cls/last.ckpt", map_location="cpu"
40
+ )
41
+ print("latent step:", state_class["global_step"])
42
+ self.classifier.load_state_dict(state_class["state_dict"], strict=False)
43
+ self.classifier.to(self.device)
44
+
45
+ self.landmarks_detector = LandmarksDetector(
46
+ "shape_predictor_68_face_landmarks.dat"
47
+ )
48
+
49
+ def predict(
50
+ self,
51
+ image: Path = Input(
52
+ description="Input image for face manipulation. Image will be aligned and cropped, "
53
+ "output aligned and manipulated images.",
54
+ ),
55
+ target_class: str = Input(
56
+ default="Bangs",
57
+ choices=[
58
+ "5_o_Clock_Shadow",
59
+ "Arched_Eyebrows",
60
+ "Attractive",
61
+ "Bags_Under_Eyes",
62
+ "Bald",
63
+ "Bangs",
64
+ "Big_Lips",
65
+ "Big_Nose",
66
+ "Black_Hair",
67
+ "Blond_Hair",
68
+ "Blurry",
69
+ "Brown_Hair",
70
+ "Bushy_Eyebrows",
71
+ "Chubby",
72
+ "Double_Chin",
73
+ "Eyeglasses",
74
+ "Goatee",
75
+ "Gray_Hair",
76
+ "Heavy_Makeup",
77
+ "High_Cheekbones",
78
+ "Male",
79
+ "Mouth_Slightly_Open",
80
+ "Mustache",
81
+ "Narrow_Eyes",
82
+ "Beard",
83
+ "Oval_Face",
84
+ "Pale_Skin",
85
+ "Pointy_Nose",
86
+ "Receding_Hairline",
87
+ "Rosy_Cheeks",
88
+ "Sideburns",
89
+ "Smiling",
90
+ "Straight_Hair",
91
+ "Wavy_Hair",
92
+ "Wearing_Earrings",
93
+ "Wearing_Hat",
94
+ "Wearing_Lipstick",
95
+ "Wearing_Necklace",
96
+ "Wearing_Necktie",
97
+ "Young",
98
+ ],
99
+ description="Choose manipulation direction.",
100
+ ),
101
+ manipulation_amplitude: float = Input(
102
+ default=0.3,
103
+ ge=-0.5,
104
+ le=0.5,
105
+ description="When set too strong, the manipulation can produce artifacts because it dominates the original image information.",
106
+ ),
107
+ T_step: int = Input(
108
+ default=100,
109
+ choices=[50, 100, 125, 200, 250, 500],
110
+ description="Number of steps for generation.",
111
+ ),
112
+ T_inv: int = Input(default=200, choices=[50, 100, 125, 200, 250, 500]),
113
+ ) -> List[ModelOutput]:
114
+
115
+ img_size = 256
116
+ print("Aligning image...")
117
+ for i, face_landmarks in enumerate(
118
+ self.landmarks_detector.get_landmarks(str(image)), start=1
119
+ ):
120
+ image_align(str(image), f"{self.aligned_dir}/aligned.png", face_landmarks)
121
+
122
+ data = ImageDataset(
123
+ self.aligned_dir,
124
+ image_size=img_size,
125
+ exts=["jpg", "jpeg", "JPG", "png"],
126
+ do_augment=False,
127
+ )
128
+
129
+ print("Encoding and Manipulating the aligned image...")
130
+ cls_manipulation_amplitude = manipulation_amplitude
131
+ interpreted_target_class = target_class
132
+ if (
133
+ target_class not in CelebAttrDataset.id_to_cls
134
+ and f"No_{target_class}" in CelebAttrDataset.id_to_cls
135
+ ):
136
+ cls_manipulation_amplitude = -manipulation_amplitude
137
+ interpreted_target_class = f"No_{target_class}"
138
+
139
+ batch = data[0]["img"][None]
140
+
141
+ semantic_latent = self.model.encode(batch.to(self.device))
142
+ stochastic_latent = self.model.encode_stochastic(
143
+ batch.to(self.device), semantic_latent, T=T_inv
144
+ )
145
+
146
+ cls_id = CelebAttrDataset.cls_to_id[interpreted_target_class]
147
+ class_direction = self.classifier.classifier.weight[cls_id]
148
+ normalized_class_direction = F.normalize(class_direction[None, :], dim=1)
149
+
150
+ normalized_semantic_latent = self.classifier.normalize(semantic_latent)
151
+ normalized_manipulation_amp = cls_manipulation_amplitude * math.sqrt(512)
152
+ normalized_manipulated_semantic_latent = (
153
+ normalized_semantic_latent
154
+ + normalized_manipulation_amp * normalized_class_direction
155
+ )
156
+
157
+ manipulated_semantic_latent = self.classifier.denormalize(
158
+ normalized_manipulated_semantic_latent
159
+ )
160
+
161
+ # Render Manipulated image
162
+ manipulated_img = self.model.render(
163
+ stochastic_latent, manipulated_semantic_latent, T=T_step
164
+ )[0]
165
+ original_img = data[0]["img"]
166
+
167
+ model_output = []
168
+ out_path = Path(tempfile.mkdtemp()) / "original_aligned.png"
169
+ save_image(convert2rgb(original_img), str(out_path))
170
+ model_output.append(ModelOutput(image=out_path))
171
+
172
+ out_path = Path(tempfile.mkdtemp()) / "manipulated_img.png"
173
+ save_image(convert2rgb(manipulated_img, adjust_scale=False), str(out_path))
174
+ model_output.append(ModelOutput(image=out_path))
175
+ return model_output
176
+
177
+
178
+ def convert2rgb(img, adjust_scale=True):
179
+ convert_img = torch.tensor(img)
180
+ if adjust_scale:
181
+ convert_img = (convert_img + 1) / 2
182
+ return convert_img.cpu()
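
The attribute edit performed inside predict() above reduces to a single linear move in the normalized semantic-latent space. A compact restatement follows, with hypothetical argument names: `z_sem` is the encoder output, `w` the classifier weight row for the chosen attribute, and `normalize`/`denormalize` stand in for `classifier.normalize` and `classifier.denormalize`.

```python
import math

import torch
import torch.nn.functional as F


def manipulate_latent(z_sem: torch.Tensor, w: torch.Tensor, amplitude: float,
                      normalize, denormalize) -> torch.Tensor:
    """Shift a [1, 512] semantic latent along the classifier weight row w ([512])."""
    direction = F.normalize(w[None, :], dim=1)   # unit-norm attribute direction
    z_norm = normalize(z_sem)                    # z-score with the classifier's latent stats
    z_norm = z_norm + amplitude * math.sqrt(512) * direction
    return denormalize(z_norm)                   # back to the latent's native scale
```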
renderer.py ADDED
@@ -0,0 +1,60 @@
+ from config import *
+
+ from torch.cuda import amp
+
+
+ def render_uncondition(conf: TrainConfig,
+                        model: BeatGANsAutoencModel,
+                        x_T,
+                        sampler: Sampler,
+                        latent_sampler: Sampler,
+                        conds_mean=None,
+                        conds_std=None,
+                        clip_latent_noise: bool = False):
+     device = x_T.device
+     if conf.train_mode == TrainMode.diffusion:
+         assert conf.model_type.can_sample()
+         return sampler.sample(model=model, noise=x_T)
+     elif conf.train_mode.is_latent_diffusion():
+         model: BeatGANsAutoencModel
+         if conf.train_mode == TrainMode.latent_diffusion:
+             latent_noise = torch.randn(len(x_T), conf.style_ch, device=device)
+         else:
+             raise NotImplementedError()
+
+         if clip_latent_noise:
+             latent_noise = latent_noise.clip(-1, 1)
+
+         cond = latent_sampler.sample(
+             model=model.latent_net,
+             noise=latent_noise,
+             clip_denoised=conf.latent_clip_sample,
+         )
+
+         if conf.latent_znormalize:
+             cond = cond * conds_std.to(device) + conds_mean.to(device)
+
+         # the diffusion on the model
+         return sampler.sample(model=model, noise=x_T, cond=cond)
+     else:
+         raise NotImplementedError()
+
+
+ def render_condition(
+     conf: TrainConfig,
+     model: BeatGANsAutoencModel,
+     x_T,
+     sampler: Sampler,
+     x_start=None,
+     cond=None,
+ ):
+     if conf.train_mode == TrainMode.diffusion:
+         assert conf.model_type.has_autoenc()
+         # returns {'cond', 'cond2'}
+         if cond is None:
+             cond = model.encode(x_start)
+         return sampler.sample(model=model,
+                               noise=x_T,
+                               model_kwargs={'cond': cond})
+     else:
+         raise NotImplementedError()
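
For orientation, a sketch of how `render_condition` is typically called to reconstruct an image from its semantic code. Here `conf`, `model`, and `sampler` stand in for objects built elsewhere in this upload (a TrainConfig from templates.py and a diffusion sampler), so treat this as illustrative only.

```python
import torch

from renderer import render_condition  # as added above


def reconstruct(conf, model, sampler, x_start: torch.Tensor) -> torch.Tensor:
    """Encode x_start and decode it back, starting from freshly drawn noise x_T."""
    x_T = torch.randn_like(x_start)  # stochastic code drawn at random (not inverted)
    return render_condition(conf, model, x_T, sampler=sampler, x_start=x_start)
```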