AmitIsraeli
commited on
Commit
•
fc8623e
1
Parent(s):
ded2f46
add checkpoint VAR trained on pops
Browse files- VAR/.gitignore +20 -0
- VAR/LICENSE +21 -0
- VAR/README.md +169 -0
- VAR/demo_sample.ipynb +176 -0
- VAR/dist.py +211 -0
- VAR/models/__init__.py +39 -0
- VAR/models/basic_vae.py +226 -0
- VAR/models/basic_var.py +174 -0
- VAR/models/helpers.py +59 -0
- VAR/models/quant.py +281 -0
- VAR/models/var.py +323 -0
- VAR/models/vqvae.py +95 -0
- VAR/requirements.txt +8 -0
- VAR/train.py +335 -0
- VAR/trainer.py +201 -0
- VAR/utils/amp_sc.py +89 -0
- VAR/utils/arg_util.py +284 -0
- VAR/utils/data.py +54 -0
- VAR/utils/data_sampler.py +103 -0
- VAR/utils/lr_control.py +108 -0
- VAR/utils/misc.py +381 -0
- infrance_example.py +19 -0
- model-step-step=32000.ckpt +3 -0
- vae_ch160v4096z32.pth +3 -0
VAR/.gitignore
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.swp
|
2 |
+
**/__pycache__/**
|
3 |
+
**/.ipynb_checkpoints/**
|
4 |
+
.DS_Store
|
5 |
+
.idea/*
|
6 |
+
.vscode/*
|
7 |
+
llava/
|
8 |
+
_vis_cached/
|
9 |
+
_auto_*
|
10 |
+
ckpt/
|
11 |
+
log/
|
12 |
+
tb*/
|
13 |
+
img*/
|
14 |
+
local_output*
|
15 |
+
*.pth
|
16 |
+
*.pth.tar
|
17 |
+
*.ckpt
|
18 |
+
*.log
|
19 |
+
*.txt
|
20 |
+
*.ipynb
|
VAR/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 FoundationVision
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
VAR/README.md
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# VAR: a new visual generation method elevates GPT-style models beyond diffusion🚀 & Scaling laws observed📈
|
2 |
+
|
3 |
+
<div align="center">
|
4 |
+
|
5 |
+
[![demo platform](https://img.shields.io/badge/Play%20with%20VAR%21-VAR%20demo%20platform-lightblue)](https://var.vision/demo)
|
6 |
+
[![arXiv](https://img.shields.io/badge/arXiv%20paper-2404.02905-b31b1b.svg)](https://arxiv.org/abs/2404.02905)
|
7 |
+
[![huggingface weights](https://img.shields.io/badge/%F0%9F%A4%97%20Weights-FoundationVision/var-yellow)](https://huggingface.co/FoundationVision/var)
|
8 |
+
[![SOTA](https://img.shields.io/badge/State%20of%20the%20Art-Image%20Generation%20on%20ImageNet%20%28AR%29-32B1B4?logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iNjA2IiBoZWlnaHQ9IjYwNiIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIiB4bWxuczp4bGluaz0iaHR0cDovL3d3dy53My5vcmcvMTk5OS94bGluayIgb3ZlcmZsb3c9ImhpZGRlbiI%2BPGRlZnM%2BPGNsaXBQYXRoIGlkPSJjbGlwMCI%2BPHJlY3QgeD0iLTEiIHk9Ii0xIiB3aWR0aD0iNjA2IiBoZWlnaHQ9IjYwNiIvPjwvY2xpcFBhdGg%2BPC9kZWZzPjxnIGNsaXAtcGF0aD0idXJsKCNjbGlwMCkiIHRyYW5zZm9ybT0idHJhbnNsYXRlKDEgMSkiPjxyZWN0IHg9IjUyOSIgeT0iNjYiIHdpZHRoPSI1NiIgaGVpZ2h0PSI0NzMiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIxOSIgeT0iNjYiIHdpZHRoPSI1NyIgaGVpZ2h0PSI0NzMiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIyNzQiIHk9IjE1MSIgd2lkdGg9IjU3IiBoZWlnaHQ9IjMwMiIgZmlsbD0iIzQ0RjJGNiIvPjxyZWN0IHg9IjEwNCIgeT0iMTUxIiB3aWR0aD0iNTciIGhlaWdodD0iMzAyIiBmaWxsPSIjNDRGMkY2Ii8%2BPHJlY3QgeD0iNDQ0IiB5PSIxNTEiIHdpZHRoPSI1NyIgaGVpZ2h0PSIzMDIiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIzNTkiIHk9IjE3MCIgd2lkdGg9IjU2IiBoZWlnaHQ9IjI2NCIgZmlsbD0iIzQ0RjJGNiIvPjxyZWN0IHg9IjE4OCIgeT0iMTcwIiB3aWR0aD0iNTciIGhlaWdodD0iMjY0IiBmaWxsPSIjNDRGMkY2Ii8%2BPHJlY3QgeD0iNzYiIHk9IjY2IiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI0ODIiIHk9IjY2IiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI3NiIgeT0iNDgyIiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI0ODIiIHk9IjQ4MiIgd2lkdGg9IjQ3IiBoZWlnaHQ9IjU3IiBmaWxsPSIjNDRGMkY2Ii8%2BPC9nPjwvc3ZnPg%3D%3D)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?tag_filter=485&p=visual-autoregressive-modeling-scalable-image)
|
9 |
+
|
10 |
+
|
11 |
+
</div>
|
12 |
+
<p align="center" style="font-size: larger;">
|
13 |
+
<a href="https://arxiv.org/abs/2404.02905">Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction</a>
|
14 |
+
</p>
|
15 |
+
|
16 |
+
<div>
|
17 |
+
<p align="center" style="font-size: larger;">
|
18 |
+
<strong>NeurIPS 2024 Oral</strong>
|
19 |
+
</p>
|
20 |
+
</div>
|
21 |
+
|
22 |
+
<p align="center">
|
23 |
+
<img src="https://github.com/FoundationVision/VAR/assets/39692511/9850df90-20b1-4f29-8592-e3526d16d755" width=95%>
|
24 |
+
<p>
|
25 |
+
|
26 |
+
<br>
|
27 |
+
|
28 |
+
## News
|
29 |
+
|
30 |
+
* **2024-09:** VAR is accepted as **NeurIPS 2024 Oral** Presentation.
|
31 |
+
* **2024-04:** [Visual AutoRegressive modeling](https://github.com/FoundationVision/VAR) is released.
|
32 |
+
|
33 |
+
## 🕹️ Try and Play with VAR!
|
34 |
+
|
35 |
+
We provide a [demo website](https://var.vision/demo) for you to play with VAR models and generate images interactively. Enjoy the fun of visual autoregressive modeling!
|
36 |
+
|
37 |
+
We also provide [demo_sample.ipynb](demo_sample.ipynb) for you to see more technical details about VAR.
|
38 |
+
|
39 |
+
[//]: # (<p align="center">)
|
40 |
+
[//]: # (<img src="https://user-images.githubusercontent.com/39692511/226376648-3f28a1a6-275d-4f88-8f3e-cd1219882488.png" width=50%)
|
41 |
+
[//]: # (<p>)
|
42 |
+
|
43 |
+
|
44 |
+
## What's New?
|
45 |
+
|
46 |
+
### 🔥 Introducing VAR: a new paradigm in autoregressive visual generation✨:
|
47 |
+
|
48 |
+
Visual Autoregressive Modeling (VAR) redefines the autoregressive learning on images as coarse-to-fine "next-scale prediction" or "next-resolution prediction", diverging from the standard raster-scan "next-token prediction".
|
49 |
+
|
50 |
+
<p align="center">
|
51 |
+
<img src="https://github.com/FoundationVision/VAR/assets/39692511/3e12655c-37dc-4528-b923-ec6c4cfef178" width=93%>
|
52 |
+
<p>
|
53 |
+
|
54 |
+
### 🔥 For the first time, GPT-style autoregressive models surpass diffusion models🚀:
|
55 |
+
<p align="center">
|
56 |
+
<img src="https://github.com/FoundationVision/VAR/assets/39692511/cc30b043-fa4e-4d01-a9b1-e50650d5675d" width=55%>
|
57 |
+
<p>
|
58 |
+
|
59 |
+
|
60 |
+
### 🔥 Discovering power-law Scaling Laws in VAR transformers📈:
|
61 |
+
|
62 |
+
|
63 |
+
<p align="center">
|
64 |
+
<img src="https://github.com/FoundationVision/VAR/assets/39692511/c35fb56e-896e-4e4b-9fb9-7a1c38513804" width=85%>
|
65 |
+
<p>
|
66 |
+
<p align="center">
|
67 |
+
<img src="https://github.com/FoundationVision/VAR/assets/39692511/91d7b92c-8fc3-44d9-8fb4-73d6cdb8ec1e" width=85%>
|
68 |
+
<p>
|
69 |
+
|
70 |
+
|
71 |
+
### 🔥 Zero-shot generalizability🛠️:
|
72 |
+
|
73 |
+
<p align="center">
|
74 |
+
<img src="https://github.com/FoundationVision/VAR/assets/39692511/a54a4e52-6793-4130-bae2-9e459a08e96a" width=70%>
|
75 |
+
<p>
|
76 |
+
|
77 |
+
#### For a deep dive into our analyses, discussions, and evaluations, check out our [paper](https://arxiv.org/abs/2404.02905).
|
78 |
+
|
79 |
+
|
80 |
+
## VAR zoo
|
81 |
+
We provide VAR models for you to play with, which are on <a href='https://huggingface.co/FoundationVision/var'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Huggingface-FoundationVision/var-yellow'></a> or can be downloaded from the following links:
|
82 |
+
|
83 |
+
| model | reso. | FID | rel. cost | #params | HF weights🤗 |
|
84 |
+
|:----------:|:-----:|:--------:|:---------:|:-------:|:------------------------------------------------------------------------------------|
|
85 |
+
| VAR-d16 | 256 | 3.55 | 0.4 | 310M | [var_d16.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d16.pth) |
|
86 |
+
| VAR-d20 | 256 | 2.95 | 0.5 | 600M | [var_d20.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d20.pth) |
|
87 |
+
| VAR-d24 | 256 | 2.33 | 0.6 | 1.0B | [var_d24.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d24.pth) |
|
88 |
+
| VAR-d30 | 256 | 1.97 | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |
|
89 |
+
| VAR-d30-re | 256 | **1.80** | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |
|
90 |
+
|
91 |
+
You can load these models to generate images via the codes in [demo_sample.ipynb](demo_sample.ipynb). Note: you need to download [vae_ch160v4096z32.pth](https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth) first.
|
92 |
+
|
93 |
+
|
94 |
+
## Installation
|
95 |
+
|
96 |
+
1. Install `torch>=2.0.0`.
|
97 |
+
2. Install other pip packages via `pip3 install -r requirements.txt`.
|
98 |
+
3. Prepare the [ImageNet](http://image-net.org/) dataset
|
99 |
+
<details>
|
100 |
+
<summary> assume the ImageNet is in `/path/to/imagenet`. It should be like this:</summary>
|
101 |
+
|
102 |
+
```
|
103 |
+
/path/to/imagenet/:
|
104 |
+
train/:
|
105 |
+
n01440764:
|
106 |
+
many_images.JPEG ...
|
107 |
+
n01443537:
|
108 |
+
many_images.JPEG ...
|
109 |
+
val/:
|
110 |
+
n01440764:
|
111 |
+
ILSVRC2012_val_00000293.JPEG ...
|
112 |
+
n01443537:
|
113 |
+
ILSVRC2012_val_00000236.JPEG ...
|
114 |
+
```
|
115 |
+
**NOTE: The arg `--data_path=/path/to/imagenet` should be passed to the training script.**
|
116 |
+
</details>
|
117 |
+
|
118 |
+
5. (Optional) install and compile `flash-attn` and `xformers` for faster attention computation. Our code will automatically use them if installed. See [models/basic_var.py#L15-L30](models/basic_var.py#L15-L30).
|
119 |
+
|
120 |
+
|
121 |
+
## Training Scripts
|
122 |
+
|
123 |
+
To train VAR-{d16, d20, d24, d30, d36-s} on ImageNet 256x256 or 512x512, you can run the following command:
|
124 |
+
```shell
|
125 |
+
# d16, 256x256
|
126 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
127 |
+
--depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1
|
128 |
+
# d20, 256x256
|
129 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
130 |
+
--depth=20 --bs=768 --ep=250 --fp16=1 --alng=1e-3 --wpe=0.1
|
131 |
+
# d24, 256x256
|
132 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
133 |
+
--depth=24 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-4 --wpe=0.01
|
134 |
+
# d30, 256x256
|
135 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
136 |
+
--depth=30 --bs=1024 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-5 --wpe=0.01 --twde=0.08
|
137 |
+
# d36-s, 512x512 (-s means saln=1, shared AdaLN)
|
138 |
+
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
|
139 |
+
--depth=36 --saln=1 --pn=512 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=5e-6 --wpe=0.01 --twde=0.08
|
140 |
+
```
|
141 |
+
A folder named `local_output` will be created to save the checkpoints and logs.
|
142 |
+
You can monitor the training process by checking the logs in `local_output/log.txt` and `local_output/stdout.txt`, or using `tensorboard --logdir=local_output/`.
|
143 |
+
|
144 |
+
If your experiment is interrupted, just rerun the command, and the training will **automatically resume** from the last checkpoint in `local_output/ckpt*.pth` (see [utils/misc.py#L344-L357](utils/misc.py#L344-L357)).
|
145 |
+
|
146 |
+
## Sampling & Zero-shot Inference
|
147 |
+
|
148 |
+
For FID evaluation, use `var.autoregressive_infer_cfg(..., cfg=1.5, top_p=0.96, top_k=900, more_smooth=False)` to sample 50,000 images (50 per class) and save them as PNG (not JPEG) files in a folder. Pack them into a `.npz` file via `create_npz_from_sample_folder(sample_folder)` in [utils/misc.py#L344](utils/misc.py#L360).
|
149 |
+
Then use the [OpenAI's FID evaluation toolkit](https://github.com/openai/guided-diffusion/tree/main/evaluations) and reference ground truth npz file of [256x256](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) or [512x512](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) to evaluate FID, IS, precision, and recall.
|
150 |
+
|
151 |
+
Note a relatively small `cfg=1.5` is used for trade-off between image quality and diversity. You can adjust it to `cfg=5.0`, or sample with `autoregressive_infer_cfg(..., more_smooth=True)` for **better visual quality**.
|
152 |
+
We'll provide the sampling script later.
|
153 |
+
|
154 |
+
## License
|
155 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
156 |
+
|
157 |
+
|
158 |
+
## Citation
|
159 |
+
If our work assists your research, feel free to give us a star ⭐ or cite us using:
|
160 |
+
```
|
161 |
+
@Article{VAR,
|
162 |
+
title={Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction},
|
163 |
+
author={Keyu Tian and Yi Jiang and Zehuan Yuan and Bingyue Peng and Liwei Wang},
|
164 |
+
year={2024},
|
165 |
+
eprint={2404.02905},
|
166 |
+
archivePrefix={arXiv},
|
167 |
+
primaryClass={cs.CV}
|
168 |
+
}
|
169 |
+
```
|
VAR/demo_sample.ipynb
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"### 🚀 For an interactive experience, head over to our [demo platform](https://var.vision/demo) and dive right in! 🌟"
|
7 |
+
],
|
8 |
+
"metadata": {
|
9 |
+
"collapsed": false
|
10 |
+
}
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 3,
|
15 |
+
"outputs": [
|
16 |
+
{
|
17 |
+
"name": "stderr",
|
18 |
+
"output_type": "stream",
|
19 |
+
"text": [
|
20 |
+
"/var/folders/xv/4sfwyf3j1q72_7wzmsc7s2t00000gn/T/ipykernel_17646/3181200107.py:32: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
21 |
+
" vae.load_state_dict(torch.load(vae_ckpt, map_location=device), strict=True)\n",
|
22 |
+
"/var/folders/xv/4sfwyf3j1q72_7wzmsc7s2t00000gn/T/ipykernel_17646/3181200107.py:33: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
23 |
+
" var.load_state_dict(torch.load(var_ckpt, map_location=device), strict=True)\n"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"name": "stdout",
|
28 |
+
"output_type": "stream",
|
29 |
+
"text": [
|
30 |
+
"prepare finished.\n"
|
31 |
+
]
|
32 |
+
}
|
33 |
+
],
|
34 |
+
"source": [
|
35 |
+
"################## 1. Download checkpoints and build models\n",
|
36 |
+
"import os\n",
|
37 |
+
"import os.path as osp\n",
|
38 |
+
"import torch, torchvision\n",
|
39 |
+
"import random\n",
|
40 |
+
"import numpy as np\n",
|
41 |
+
"import PIL.Image as PImage, PIL.ImageDraw as PImageDraw\n",
|
42 |
+
"setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed\n",
|
43 |
+
"from models import VQVAE, build_vae_var\n",
|
44 |
+
"\n",
|
45 |
+
"MODEL_DEPTH = 16 # TODO: =====> please specify MODEL_DEPTH <=====\n",
|
46 |
+
"assert MODEL_DEPTH in {16, 20, 24, 30}\n",
|
47 |
+
"\n",
|
48 |
+
"\n",
|
49 |
+
"# download checkpoint\n",
|
50 |
+
"hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'\n",
|
51 |
+
"vae_ckpt, var_ckpt = '/Users/mac/Downloads/vae_ch160v4096z32.pth', '/Users/mac/Downloads/var_d16.pth'\n",
|
52 |
+
"if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')\n",
|
53 |
+
"if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')\n",
|
54 |
+
"\n",
|
55 |
+
"# build vae, var\n",
|
56 |
+
"patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)\n",
|
57 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
58 |
+
"if 'vae' not in globals() or 'var' not in globals():\n",
|
59 |
+
" vae, var = build_vae_var(\n",
|
60 |
+
" V=4096, Cvae=32, ch=160, share_quant_resi=4, # hard-coded VQVAE hyperparameters\n",
|
61 |
+
" device=device, patch_nums=patch_nums,\n",
|
62 |
+
" num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,\n",
|
63 |
+
" )\n",
|
64 |
+
"\n",
|
65 |
+
"# load checkpoints\n",
|
66 |
+
"vae.load_state_dict(torch.load(vae_ckpt, map_location=device), strict=True)\n",
|
67 |
+
"var.load_state_dict(torch.load(var_ckpt, map_location=device), strict=True)\n",
|
68 |
+
"vae.eval(), var.eval()\n",
|
69 |
+
"for p in vae.parameters(): p.requires_grad_(False)\n",
|
70 |
+
"for p in var.parameters(): p.requires_grad_(False)\n",
|
71 |
+
"print(f'prepare finished.')"
|
72 |
+
],
|
73 |
+
"metadata": {
|
74 |
+
"collapsed": false,
|
75 |
+
"ExecuteTime": {
|
76 |
+
"end_time": "2024-11-09T23:44:58.680070Z",
|
77 |
+
"start_time": "2024-11-09T23:44:58.094258Z"
|
78 |
+
}
|
79 |
+
}
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"cell_type": "code",
|
83 |
+
"execution_count": 4,
|
84 |
+
"outputs": [
|
85 |
+
{
|
86 |
+
"ename": "RuntimeError",
|
87 |
+
"evalue": "Placeholder storage has not been allocated on MPS device!",
|
88 |
+
"output_type": "error",
|
89 |
+
"traceback": [
|
90 |
+
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
|
91 |
+
"\u001B[0;31mRuntimeError\u001B[0m Traceback (most recent call last)",
|
92 |
+
"Cell \u001B[0;32mIn[4], line 28\u001B[0m\n\u001B[1;32m 26\u001B[0m label_B: torch\u001B[38;5;241m.\u001B[39mLongTensor \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mtensor(class_labels, device\u001B[38;5;241m=\u001B[39mdevice)\n\u001B[1;32m 27\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m torch\u001B[38;5;241m.\u001B[39minference_mode():\n\u001B[0;32m---> 28\u001B[0m recon_B3HW \u001B[38;5;241m=\u001B[39m \u001B[43mvar\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautoregressive_infer_cfg\u001B[49m\u001B[43m(\u001B[49m\u001B[43mB\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mB\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlabel_B\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mlabel_B\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcfg\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcfg\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtop_k\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m900\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtop_p\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0.95\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mg_seed\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mseed\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmore_smooth\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mmore_smooth\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 30\u001B[0m chw \u001B[38;5;241m=\u001B[39m torchvision\u001B[38;5;241m.\u001B[39mutils\u001B[38;5;241m.\u001B[39mmake_grid(recon_B3HW, nrow\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m8\u001B[39m, padding\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m, pad_value\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1.0\u001B[39m)\n\u001B[1;32m 31\u001B[0m chw \u001B[38;5;241m=\u001B[39m chw\u001B[38;5;241m.\u001B[39mpermute(\u001B[38;5;241m1\u001B[39m, \u001B[38;5;241m2\u001B[39m, \u001B[38;5;241m0\u001B[39m)\u001B[38;5;241m.\u001B[39mmul_(\u001B[38;5;241m255\u001B[39m)\u001B[38;5;241m.\u001B[39mcpu()\u001B[38;5;241m.\u001B[39mnumpy()\n",
|
93 |
+
"File \u001B[0;32m~/PycharmProjects/ALL_Shit/venv/lib/python3.9/site-packages/torch/utils/_contextlib.py:116\u001B[0m, in \u001B[0;36mcontext_decorator.<locals>.decorate_context\u001B[0;34m(*args, **kwargs)\u001B[0m\n\u001B[1;32m 113\u001B[0m \u001B[38;5;129m@functools\u001B[39m\u001B[38;5;241m.\u001B[39mwraps(func)\n\u001B[1;32m 114\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mdecorate_context\u001B[39m(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m 115\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m ctx_factory():\n\u001B[0;32m--> 116\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
|
94 |
+
"File \u001B[0;32m~/PycharmProjects/VAR_clip/VAR/models/var.py:154\u001B[0m, in \u001B[0;36mVAR.autoregressive_infer_cfg\u001B[0;34m(self, B, label_B, cond_delta, g_seed, cfg, top_k, top_p, beta, more_smooth)\u001B[0m\n\u001B[1;32m 151\u001B[0m \u001B[38;5;28;01melif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(label_B, \u001B[38;5;28mint\u001B[39m):\n\u001B[1;32m 152\u001B[0m label_B \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mfull((B,), fill_value\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mnum_classes \u001B[38;5;28;01mif\u001B[39;00m label_B \u001B[38;5;241m<\u001B[39m \u001B[38;5;241m0\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m label_B, device\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mlvl_1L\u001B[38;5;241m.\u001B[39mdevice)\n\u001B[0;32m--> 154\u001B[0m sos \u001B[38;5;241m=\u001B[39m cond_BD \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mclass_emb\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcat\u001B[49m\u001B[43m(\u001B[49m\u001B[43m(\u001B[49m\u001B[43mlabel_B\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfull_like\u001B[49m\u001B[43m(\u001B[49m\u001B[43mlabel_B\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfill_value\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mnum_classes\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 155\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m cond_delta \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m 156\u001B[0m cond_BD[:B] \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m cond_delta \u001B[38;5;241m*\u001B[39m beta\n",
|
95 |
+
"File \u001B[0;32m~/PycharmProjects/ALL_Shit/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1553\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m 1551\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[1;32m 1552\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m-> 1553\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
|
96 |
+
"File \u001B[0;32m~/PycharmProjects/ALL_Shit/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1562\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m 1557\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[1;32m 1558\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[1;32m 1559\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[1;32m 1560\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[1;32m 1561\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[0;32m-> 1562\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1564\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m 1565\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n",
|
97 |
+
"File \u001B[0;32m~/PycharmProjects/ALL_Shit/venv/lib/python3.9/site-packages/torch/nn/modules/sparse.py:164\u001B[0m, in \u001B[0;36mEmbedding.forward\u001B[0;34m(self, input)\u001B[0m\n\u001B[1;32m 163\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;28minput\u001B[39m: Tensor) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m Tensor:\n\u001B[0;32m--> 164\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mF\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43membedding\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 165\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43minput\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mweight\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mpadding_idx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mmax_norm\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 166\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mnorm_type\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mscale_grad_by_freq\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msparse\u001B[49m\u001B[43m)\u001B[49m\n",
|
98 |
+
"File \u001B[0;32m~/PycharmProjects/ALL_Shit/venv/lib/python3.9/site-packages/torch/nn/functional.py:2267\u001B[0m, in \u001B[0;36membedding\u001B[0;34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001B[0m\n\u001B[1;32m 2261\u001B[0m \u001B[38;5;66;03m# Note [embedding_renorm set_grad_enabled]\u001B[39;00m\n\u001B[1;32m 2262\u001B[0m \u001B[38;5;66;03m# XXX: equivalent to\u001B[39;00m\n\u001B[1;32m 2263\u001B[0m \u001B[38;5;66;03m# with torch.no_grad():\u001B[39;00m\n\u001B[1;32m 2264\u001B[0m \u001B[38;5;66;03m# torch.embedding_renorm_\u001B[39;00m\n\u001B[1;32m 2265\u001B[0m \u001B[38;5;66;03m# remove once script supports set_grad_enabled\u001B[39;00m\n\u001B[1;32m 2266\u001B[0m _no_grad_embedding_renorm_(weight, \u001B[38;5;28minput\u001B[39m, max_norm, norm_type)\n\u001B[0;32m-> 2267\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43membedding\u001B[49m\u001B[43m(\u001B[49m\u001B[43mweight\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43minput\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpadding_idx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mscale_grad_by_freq\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msparse\u001B[49m\u001B[43m)\u001B[49m\n",
|
99 |
+
"\u001B[0;31mRuntimeError\u001B[0m: Placeholder storage has not been allocated on MPS device!"
|
100 |
+
]
|
101 |
+
}
|
102 |
+
],
|
103 |
+
"source": [
|
104 |
+
"############################# 2. Sample with classifier-free guidance\n",
|
105 |
+
"\n",
|
106 |
+
"# set args\n",
|
107 |
+
"seed = 0 #@param {type:\"number\"}\n",
|
108 |
+
"torch.manual_seed(seed)\n",
|
109 |
+
"num_sampling_steps = 250 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
|
110 |
+
"cfg = 4 #@param {type:\"slider\", min:1, max:10, step:0.1}\n",
|
111 |
+
"class_labels = (980, 980, 437, 437, 22, 22, 562, 562) #@param {type:\"raw\"}\n",
|
112 |
+
"more_smooth = False # True for more smooth output\n",
|
113 |
+
"\n",
|
114 |
+
"# seed\n",
|
115 |
+
"torch.manual_seed(seed)\n",
|
116 |
+
"random.seed(seed)\n",
|
117 |
+
"np.random.seed(seed)\n",
|
118 |
+
"torch.backends.cudnn.deterministic = True\n",
|
119 |
+
"torch.backends.cudnn.benchmark = False\n",
|
120 |
+
"\n",
|
121 |
+
"# run faster\n",
|
122 |
+
"tf32 = True\n",
|
123 |
+
"torch.backends.cudnn.allow_tf32 = bool(tf32)\n",
|
124 |
+
"torch.backends.cuda.matmul.allow_tf32 = bool(tf32)\n",
|
125 |
+
"torch.set_float32_matmul_precision('high' if tf32 else 'highest')\n",
|
126 |
+
"\n",
|
127 |
+
"# sample\n",
|
128 |
+
"B = len(class_labels)\n",
|
129 |
+
"label_B: torch.LongTensor = torch.tensor(class_labels, device=device)\n",
|
130 |
+
"with torch.inference_mode():\n",
|
131 |
+
" recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)\n",
|
132 |
+
"\n",
|
133 |
+
"chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)\n",
|
134 |
+
"chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()\n",
|
135 |
+
"chw = PImage.fromarray(chw.astype(np.uint8))\n",
|
136 |
+
"chw.show()\n"
|
137 |
+
],
|
138 |
+
"metadata": {
|
139 |
+
"collapsed": false,
|
140 |
+
"ExecuteTime": {
|
141 |
+
"end_time": "2024-11-09T23:44:59.930244Z",
|
142 |
+
"start_time": "2024-11-09T23:44:59.454228Z"
|
143 |
+
}
|
144 |
+
}
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "code",
|
148 |
+
"outputs": [],
|
149 |
+
"source": [],
|
150 |
+
"metadata": {
|
151 |
+
"collapsed": false
|
152 |
+
}
|
153 |
+
}
|
154 |
+
],
|
155 |
+
"metadata": {
|
156 |
+
"kernelspec": {
|
157 |
+
"display_name": "Python 3",
|
158 |
+
"language": "python",
|
159 |
+
"name": "python3"
|
160 |
+
},
|
161 |
+
"language_info": {
|
162 |
+
"codemirror_mode": {
|
163 |
+
"name": "ipython",
|
164 |
+
"version": 2
|
165 |
+
},
|
166 |
+
"file_extension": ".py",
|
167 |
+
"mimetype": "text/x-python",
|
168 |
+
"name": "python",
|
169 |
+
"nbconvert_exporter": "python",
|
170 |
+
"pygments_lexer": "ipython2",
|
171 |
+
"version": "2.7.6"
|
172 |
+
}
|
173 |
+
},
|
174 |
+
"nbformat": 4,
|
175 |
+
"nbformat_minor": 0
|
176 |
+
}
|
VAR/dist.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import functools
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from typing import List
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.distributed as tdist
|
10 |
+
import torch.multiprocessing as mp
|
11 |
+
|
12 |
+
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
|
13 |
+
__initialized = False
|
14 |
+
|
15 |
+
|
16 |
+
def initialized():
|
17 |
+
return __initialized
|
18 |
+
|
19 |
+
|
20 |
+
def initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout=30):
|
21 |
+
global __device
|
22 |
+
if not torch.cuda.is_available():
|
23 |
+
print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
|
24 |
+
return
|
25 |
+
elif 'RANK' not in os.environ:
|
26 |
+
torch.cuda.set_device(gpu_id_if_not_distibuted)
|
27 |
+
__device = torch.empty(1).cuda().device
|
28 |
+
print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
|
29 |
+
return
|
30 |
+
# then 'RANK' must exist
|
31 |
+
global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
|
32 |
+
local_rank = global_rank % num_gpus
|
33 |
+
torch.cuda.set_device(local_rank)
|
34 |
+
|
35 |
+
# ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
|
36 |
+
if mp.get_start_method(allow_none=True) is None:
|
37 |
+
method = 'fork' if fork else 'spawn'
|
38 |
+
print(f'[dist initialize] mp method={method}')
|
39 |
+
mp.set_start_method(method)
|
40 |
+
tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout*60))
|
41 |
+
|
42 |
+
global __rank, __local_rank, __world_size, __initialized
|
43 |
+
__local_rank = local_rank
|
44 |
+
__rank, __world_size = tdist.get_rank(), tdist.get_world_size()
|
45 |
+
__device = torch.empty(1).cuda().device
|
46 |
+
__initialized = True
|
47 |
+
|
48 |
+
assert tdist.is_initialized(), 'torch.distributed is not initialized!'
|
49 |
+
print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
|
50 |
+
|
51 |
+
|
52 |
+
def get_rank():
|
53 |
+
return __rank
|
54 |
+
|
55 |
+
|
56 |
+
def get_local_rank():
|
57 |
+
return __local_rank
|
58 |
+
|
59 |
+
|
60 |
+
def get_world_size():
|
61 |
+
return __world_size
|
62 |
+
|
63 |
+
|
64 |
+
def get_device():
|
65 |
+
return __device
|
66 |
+
|
67 |
+
|
68 |
+
def set_gpu_id(gpu_id: int):
|
69 |
+
if gpu_id is None: return
|
70 |
+
global __device
|
71 |
+
if isinstance(gpu_id, (str, int)):
|
72 |
+
torch.cuda.set_device(int(gpu_id))
|
73 |
+
__device = torch.empty(1).cuda().device
|
74 |
+
else:
|
75 |
+
raise NotImplementedError
|
76 |
+
|
77 |
+
|
78 |
+
def is_master():
|
79 |
+
return __rank == 0
|
80 |
+
|
81 |
+
|
82 |
+
def is_local_master():
|
83 |
+
return __local_rank == 0
|
84 |
+
|
85 |
+
|
86 |
+
def new_group(ranks: List[int]):
|
87 |
+
if __initialized:
|
88 |
+
return tdist.new_group(ranks=ranks)
|
89 |
+
return None
|
90 |
+
|
91 |
+
|
92 |
+
def barrier():
|
93 |
+
if __initialized:
|
94 |
+
tdist.barrier()
|
95 |
+
|
96 |
+
|
97 |
+
def allreduce(t: torch.Tensor, async_op=False):
|
98 |
+
if __initialized:
|
99 |
+
if not t.is_cuda:
|
100 |
+
cu = t.detach().cuda()
|
101 |
+
ret = tdist.all_reduce(cu, async_op=async_op)
|
102 |
+
t.copy_(cu.cpu())
|
103 |
+
else:
|
104 |
+
ret = tdist.all_reduce(t, async_op=async_op)
|
105 |
+
return ret
|
106 |
+
return None
|
107 |
+
|
108 |
+
|
109 |
+
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
|
110 |
+
if __initialized:
|
111 |
+
if not t.is_cuda:
|
112 |
+
t = t.cuda()
|
113 |
+
ls = [torch.empty_like(t) for _ in range(__world_size)]
|
114 |
+
tdist.all_gather(ls, t)
|
115 |
+
else:
|
116 |
+
ls = [t]
|
117 |
+
if cat:
|
118 |
+
ls = torch.cat(ls, dim=0)
|
119 |
+
return ls
|
120 |
+
|
121 |
+
|
122 |
+
def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
|
123 |
+
if __initialized:
|
124 |
+
if not t.is_cuda:
|
125 |
+
t = t.cuda()
|
126 |
+
|
127 |
+
t_size = torch.tensor(t.size(), device=t.device)
|
128 |
+
ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
|
129 |
+
tdist.all_gather(ls_size, t_size)
|
130 |
+
|
131 |
+
max_B = max(size[0].item() for size in ls_size)
|
132 |
+
pad = max_B - t_size[0].item()
|
133 |
+
if pad:
|
134 |
+
pad_size = (pad, *t.size()[1:])
|
135 |
+
t = torch.cat((t, t.new_empty(pad_size)), dim=0)
|
136 |
+
|
137 |
+
ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
|
138 |
+
tdist.all_gather(ls_padded, t)
|
139 |
+
ls = []
|
140 |
+
for t, size in zip(ls_padded, ls_size):
|
141 |
+
ls.append(t[:size[0].item()])
|
142 |
+
else:
|
143 |
+
ls = [t]
|
144 |
+
if cat:
|
145 |
+
ls = torch.cat(ls, dim=0)
|
146 |
+
return ls
|
147 |
+
|
148 |
+
|
149 |
+
def broadcast(t: torch.Tensor, src_rank) -> None:
|
150 |
+
if __initialized:
|
151 |
+
if not t.is_cuda:
|
152 |
+
cu = t.detach().cuda()
|
153 |
+
tdist.broadcast(cu, src=src_rank)
|
154 |
+
t.copy_(cu.cpu())
|
155 |
+
else:
|
156 |
+
tdist.broadcast(t, src=src_rank)
|
157 |
+
|
158 |
+
|
159 |
+
def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
|
160 |
+
if not initialized():
|
161 |
+
return torch.tensor([val]) if fmt is None else [fmt % val]
|
162 |
+
|
163 |
+
ts = torch.zeros(__world_size)
|
164 |
+
ts[__rank] = val
|
165 |
+
allreduce(ts)
|
166 |
+
if fmt is None:
|
167 |
+
return ts
|
168 |
+
return [fmt % v for v in ts.cpu().numpy().tolist()]
|
169 |
+
|
170 |
+
|
171 |
+
def master_only(func):
|
172 |
+
@functools.wraps(func)
|
173 |
+
def wrapper(*args, **kwargs):
|
174 |
+
force = kwargs.pop('force', False)
|
175 |
+
if force or is_master():
|
176 |
+
ret = func(*args, **kwargs)
|
177 |
+
else:
|
178 |
+
ret = None
|
179 |
+
barrier()
|
180 |
+
return ret
|
181 |
+
return wrapper
|
182 |
+
|
183 |
+
|
184 |
+
def local_master_only(func):
|
185 |
+
@functools.wraps(func)
|
186 |
+
def wrapper(*args, **kwargs):
|
187 |
+
force = kwargs.pop('force', False)
|
188 |
+
if force or is_local_master():
|
189 |
+
ret = func(*args, **kwargs)
|
190 |
+
else:
|
191 |
+
ret = None
|
192 |
+
barrier()
|
193 |
+
return ret
|
194 |
+
return wrapper
|
195 |
+
|
196 |
+
|
197 |
+
def for_visualize(func):
|
198 |
+
@functools.wraps(func)
|
199 |
+
def wrapper(*args, **kwargs):
|
200 |
+
if is_master():
|
201 |
+
# with torch.no_grad():
|
202 |
+
ret = func(*args, **kwargs)
|
203 |
+
else:
|
204 |
+
ret = None
|
205 |
+
return ret
|
206 |
+
return wrapper
|
207 |
+
|
208 |
+
|
209 |
+
def finalize():
|
210 |
+
if __initialized:
|
211 |
+
tdist.destroy_process_group()
|
VAR/models/__init__.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .quant import VectorQuantizer2
|
5 |
+
from .var import VAR
|
6 |
+
from .vqvae import VQVAE
|
7 |
+
|
8 |
+
|
9 |
+
def build_vae_var(
|
10 |
+
# Shared args
|
11 |
+
device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
|
12 |
+
# VQVAE args
|
13 |
+
V=4096, Cvae=32, ch=160, share_quant_resi=4,
|
14 |
+
# VAR args
|
15 |
+
num_classes=1000, depth=16, shared_aln=False, attn_l2_norm=True,
|
16 |
+
flash_if_available=True, fused_if_available=True,
|
17 |
+
init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1, # init_std < 0: automated
|
18 |
+
) -> Tuple[VQVAE, VAR]:
|
19 |
+
heads = depth
|
20 |
+
width = depth * 64
|
21 |
+
dpr = 0.1 * depth/24
|
22 |
+
|
23 |
+
# disable built-in initialization for speed
|
24 |
+
for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d):
|
25 |
+
setattr(clz, 'reset_parameters', lambda self: None)
|
26 |
+
|
27 |
+
# build models
|
28 |
+
vae_local = VQVAE(vocab_size=V, z_channels=Cvae, ch=ch, test_mode=True, share_quant_resi=share_quant_resi, v_patch_nums=patch_nums).to(device)
|
29 |
+
var_wo_ddp = VAR(
|
30 |
+
vae_local=vae_local,
|
31 |
+
num_classes=num_classes, depth=depth, embed_dim=width, num_heads=heads, drop_rate=0., attn_drop_rate=0., drop_path_rate=dpr,
|
32 |
+
norm_eps=1e-6, shared_aln=shared_aln, cond_drop_rate=0.1,
|
33 |
+
attn_l2_norm=attn_l2_norm,
|
34 |
+
patch_nums=patch_nums,
|
35 |
+
flash_if_available=flash_if_available, fused_if_available=fused_if_available,
|
36 |
+
).to(device)
|
37 |
+
var_wo_ddp.init_weights(init_adaln=init_adaln, init_adaln_gamma=init_adaln_gamma, init_head=init_head, init_std=init_std)
|
38 |
+
|
39 |
+
return vae_local, var_wo_ddp
|
VAR/models/basic_vae.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
# this file only provides the 2 modules used in VQVAE
|
7 |
+
__all__ = ['Encoder', 'Decoder',]
|
8 |
+
|
9 |
+
|
10 |
+
"""
|
11 |
+
References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
|
12 |
+
"""
|
13 |
+
# swish
|
14 |
+
def nonlinearity(x):
|
15 |
+
return x * torch.sigmoid(x)
|
16 |
+
|
17 |
+
|
18 |
+
def Normalize(in_channels, num_groups=32):
|
19 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
20 |
+
|
21 |
+
|
22 |
+
class Upsample2x(nn.Module):
|
23 |
+
def __init__(self, in_channels):
|
24 |
+
super().__init__()
|
25 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
|
29 |
+
|
30 |
+
|
31 |
+
class Downsample2x(nn.Module):
|
32 |
+
def __init__(self, in_channels):
|
33 |
+
super().__init__()
|
34 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0))
|
38 |
+
|
39 |
+
|
40 |
+
class ResnetBlock(nn.Module):
|
41 |
+
def __init__(self, *, in_channels, out_channels=None, dropout): # conv_shortcut=False, # conv_shortcut: always False in VAE
|
42 |
+
super().__init__()
|
43 |
+
self.in_channels = in_channels
|
44 |
+
out_channels = in_channels if out_channels is None else out_channels
|
45 |
+
self.out_channels = out_channels
|
46 |
+
|
47 |
+
self.norm1 = Normalize(in_channels)
|
48 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
49 |
+
self.norm2 = Normalize(out_channels)
|
50 |
+
self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()
|
51 |
+
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
52 |
+
if self.in_channels != self.out_channels:
|
53 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
54 |
+
else:
|
55 |
+
self.nin_shortcut = nn.Identity()
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
h = self.conv1(F.silu(self.norm1(x), inplace=True))
|
59 |
+
h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True)))
|
60 |
+
return self.nin_shortcut(x) + h
|
61 |
+
|
62 |
+
|
63 |
+
class AttnBlock(nn.Module):
|
64 |
+
def __init__(self, in_channels):
|
65 |
+
super().__init__()
|
66 |
+
self.C = in_channels
|
67 |
+
|
68 |
+
self.norm = Normalize(in_channels)
|
69 |
+
self.qkv = torch.nn.Conv2d(in_channels, 3*in_channels, kernel_size=1, stride=1, padding=0)
|
70 |
+
self.w_ratio = int(in_channels) ** (-0.5)
|
71 |
+
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
qkv = self.qkv(self.norm(x))
|
75 |
+
B, _, H, W = qkv.shape # should be B,3C,H,W
|
76 |
+
C = self.C
|
77 |
+
q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)
|
78 |
+
|
79 |
+
# compute attention
|
80 |
+
q = q.view(B, C, H * W).contiguous()
|
81 |
+
q = q.permute(0, 2, 1).contiguous() # B,HW,C
|
82 |
+
k = k.view(B, C, H * W).contiguous() # B,C,HW
|
83 |
+
w = torch.bmm(q, k).mul_(self.w_ratio) # B,HW,HW w[B,i,j]=sum_c q[B,i,C]k[B,C,j]
|
84 |
+
w = F.softmax(w, dim=2)
|
85 |
+
|
86 |
+
# attend to values
|
87 |
+
v = v.view(B, C, H * W).contiguous()
|
88 |
+
w = w.permute(0, 2, 1).contiguous() # B,HW,HW (first HW of k, second of q)
|
89 |
+
h = torch.bmm(v, w) # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]
|
90 |
+
h = h.view(B, C, H, W).contiguous()
|
91 |
+
|
92 |
+
return x + self.proj_out(h)
|
93 |
+
|
94 |
+
|
95 |
+
def make_attn(in_channels, using_sa=True):
|
96 |
+
return AttnBlock(in_channels) if using_sa else nn.Identity()
|
97 |
+
|
98 |
+
|
99 |
+
class Encoder(nn.Module):
|
100 |
+
def __init__(
|
101 |
+
self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
|
102 |
+
dropout=0.0, in_channels=3,
|
103 |
+
z_channels, double_z=False, using_sa=True, using_mid_sa=True,
|
104 |
+
):
|
105 |
+
super().__init__()
|
106 |
+
self.ch = ch
|
107 |
+
self.num_resolutions = len(ch_mult)
|
108 |
+
self.downsample_ratio = 2 ** (self.num_resolutions - 1)
|
109 |
+
self.num_res_blocks = num_res_blocks
|
110 |
+
self.in_channels = in_channels
|
111 |
+
|
112 |
+
# downsampling
|
113 |
+
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
114 |
+
|
115 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
116 |
+
self.down = nn.ModuleList()
|
117 |
+
for i_level in range(self.num_resolutions):
|
118 |
+
block = nn.ModuleList()
|
119 |
+
attn = nn.ModuleList()
|
120 |
+
block_in = ch * in_ch_mult[i_level]
|
121 |
+
block_out = ch * ch_mult[i_level]
|
122 |
+
for i_block in range(self.num_res_blocks):
|
123 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
|
124 |
+
block_in = block_out
|
125 |
+
if i_level == self.num_resolutions - 1 and using_sa:
|
126 |
+
attn.append(make_attn(block_in, using_sa=True))
|
127 |
+
down = nn.Module()
|
128 |
+
down.block = block
|
129 |
+
down.attn = attn
|
130 |
+
if i_level != self.num_resolutions - 1:
|
131 |
+
down.downsample = Downsample2x(block_in)
|
132 |
+
self.down.append(down)
|
133 |
+
|
134 |
+
# middle
|
135 |
+
self.mid = nn.Module()
|
136 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
|
137 |
+
self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
|
138 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
|
139 |
+
|
140 |
+
# end
|
141 |
+
self.norm_out = Normalize(block_in)
|
142 |
+
self.conv_out = torch.nn.Conv2d(block_in, (2 * z_channels if double_z else z_channels), kernel_size=3, stride=1, padding=1)
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
# downsampling
|
146 |
+
h = self.conv_in(x)
|
147 |
+
for i_level in range(self.num_resolutions):
|
148 |
+
for i_block in range(self.num_res_blocks):
|
149 |
+
h = self.down[i_level].block[i_block](h)
|
150 |
+
if len(self.down[i_level].attn) > 0:
|
151 |
+
h = self.down[i_level].attn[i_block](h)
|
152 |
+
if i_level != self.num_resolutions - 1:
|
153 |
+
h = self.down[i_level].downsample(h)
|
154 |
+
|
155 |
+
# middle
|
156 |
+
h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))
|
157 |
+
|
158 |
+
# end
|
159 |
+
h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
|
160 |
+
return h
|
161 |
+
|
162 |
+
|
163 |
+
class Decoder(nn.Module):
|
164 |
+
def __init__(
|
165 |
+
self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
|
166 |
+
dropout=0.0, in_channels=3, # in_channels: raw img channels
|
167 |
+
z_channels, using_sa=True, using_mid_sa=True,
|
168 |
+
):
|
169 |
+
super().__init__()
|
170 |
+
self.ch = ch
|
171 |
+
self.num_resolutions = len(ch_mult)
|
172 |
+
self.num_res_blocks = num_res_blocks
|
173 |
+
self.in_channels = in_channels
|
174 |
+
|
175 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
176 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
177 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
178 |
+
|
179 |
+
# z to block_in
|
180 |
+
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
181 |
+
|
182 |
+
# middle
|
183 |
+
self.mid = nn.Module()
|
184 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
|
185 |
+
self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
|
186 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
|
187 |
+
|
188 |
+
# upsampling
|
189 |
+
self.up = nn.ModuleList()
|
190 |
+
for i_level in reversed(range(self.num_resolutions)):
|
191 |
+
block = nn.ModuleList()
|
192 |
+
attn = nn.ModuleList()
|
193 |
+
block_out = ch * ch_mult[i_level]
|
194 |
+
for i_block in range(self.num_res_blocks + 1):
|
195 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
|
196 |
+
block_in = block_out
|
197 |
+
if i_level == self.num_resolutions-1 and using_sa:
|
198 |
+
attn.append(make_attn(block_in, using_sa=True))
|
199 |
+
up = nn.Module()
|
200 |
+
up.block = block
|
201 |
+
up.attn = attn
|
202 |
+
if i_level != 0:
|
203 |
+
up.upsample = Upsample2x(block_in)
|
204 |
+
self.up.insert(0, up) # prepend to get consistent order
|
205 |
+
|
206 |
+
# end
|
207 |
+
self.norm_out = Normalize(block_in)
|
208 |
+
self.conv_out = torch.nn.Conv2d(block_in, in_channels, kernel_size=3, stride=1, padding=1)
|
209 |
+
|
210 |
+
def forward(self, z):
|
211 |
+
# z to block_in
|
212 |
+
# middle
|
213 |
+
h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))
|
214 |
+
|
215 |
+
# upsampling
|
216 |
+
for i_level in reversed(range(self.num_resolutions)):
|
217 |
+
for i_block in range(self.num_res_blocks + 1):
|
218 |
+
h = self.up[i_level].block[i_block](h)
|
219 |
+
if len(self.up[i_level].attn) > 0:
|
220 |
+
h = self.up[i_level].attn[i_block](h)
|
221 |
+
if i_level != 0:
|
222 |
+
h = self.up[i_level].upsample(h)
|
223 |
+
|
224 |
+
# end
|
225 |
+
h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
|
226 |
+
return h
|
VAR/models/basic_var.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from models.helpers import DropPath, drop_path
|
8 |
+
|
9 |
+
|
10 |
+
# this file only provides the 3 blocks used in VAR transformer
|
11 |
+
__all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead']
|
12 |
+
|
13 |
+
|
14 |
+
# automatically import fused operators
|
15 |
+
dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None
|
16 |
+
try:
|
17 |
+
from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
18 |
+
from flash_attn.ops.fused_dense import fused_mlp_func
|
19 |
+
except ImportError: pass
|
20 |
+
# automatically import faster attention implementations
|
21 |
+
try: from xformers.ops import memory_efficient_attention
|
22 |
+
except ImportError: pass
|
23 |
+
try: from flash_attn import flash_attn_func # qkv: BLHc, ret: BLHcq
|
24 |
+
except ImportError: pass
|
25 |
+
try: from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc
|
26 |
+
except ImportError:
|
27 |
+
def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):
|
28 |
+
attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL
|
29 |
+
if attn_mask is not None: attn.add_(attn_mask)
|
30 |
+
return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value
|
31 |
+
|
32 |
+
|
33 |
+
class FFN(nn.Module):
|
34 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True):
|
35 |
+
super().__init__()
|
36 |
+
self.fused_mlp_func = fused_mlp_func if fused_if_available else None
|
37 |
+
out_features = out_features or in_features
|
38 |
+
hidden_features = hidden_features or in_features
|
39 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
40 |
+
self.act = nn.GELU(approximate='tanh')
|
41 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
42 |
+
self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
if self.fused_mlp_func is not None:
|
46 |
+
return self.drop(self.fused_mlp_func(
|
47 |
+
x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias,
|
48 |
+
activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0,
|
49 |
+
heuristic=0, process_group=None,
|
50 |
+
))
|
51 |
+
else:
|
52 |
+
return self.drop(self.fc2( self.act(self.fc1(x)) ))
|
53 |
+
|
54 |
+
def extra_repr(self) -> str:
|
55 |
+
return f'fused_mlp_func={self.fused_mlp_func is not None}'
|
56 |
+
|
57 |
+
|
58 |
+
class SelfAttention(nn.Module):
|
59 |
+
def __init__(
|
60 |
+
self, block_idx, embed_dim=768, num_heads=12,
|
61 |
+
attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
assert embed_dim % num_heads == 0
|
65 |
+
self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads # =64
|
66 |
+
self.attn_l2_norm = attn_l2_norm
|
67 |
+
if self.attn_l2_norm:
|
68 |
+
self.scale = 1
|
69 |
+
self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
|
70 |
+
self.max_scale_mul = torch.log(torch.tensor(100)).item()
|
71 |
+
else:
|
72 |
+
self.scale = 0.25 / math.sqrt(self.head_dim)
|
73 |
+
|
74 |
+
self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
|
75 |
+
self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
|
76 |
+
self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
|
77 |
+
|
78 |
+
self.proj = nn.Linear(embed_dim, embed_dim)
|
79 |
+
self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
|
80 |
+
self.attn_drop: float = attn_drop
|
81 |
+
self.using_flash = flash_if_available and flash_attn_func is not None
|
82 |
+
self.using_xform = flash_if_available and memory_efficient_attention is not None
|
83 |
+
|
84 |
+
# only used during inference
|
85 |
+
self.caching, self.cached_k, self.cached_v = False, None, None
|
86 |
+
|
87 |
+
def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None
|
88 |
+
|
89 |
+
# NOTE: attn_bias is None during inference because kv cache is enabled
|
90 |
+
def forward(self, x, attn_bias):
|
91 |
+
B, L, C = x.shape
|
92 |
+
|
93 |
+
qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim)
|
94 |
+
main_type = qkv.dtype
|
95 |
+
# qkv: BL3Hc
|
96 |
+
|
97 |
+
using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
|
98 |
+
if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1 # q or k or v: BLHc
|
99 |
+
else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2 # q or k or v: BHLc
|
100 |
+
|
101 |
+
if self.attn_l2_norm:
|
102 |
+
scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
|
103 |
+
if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2) # 1H11 to 11H1
|
104 |
+
q = F.normalize(q, dim=-1).mul(scale_mul)
|
105 |
+
k = F.normalize(k, dim=-1)
|
106 |
+
|
107 |
+
if self.caching:
|
108 |
+
if self.cached_k is None: self.cached_k = k; self.cached_v = v
|
109 |
+
else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)
|
110 |
+
|
111 |
+
dropout_p = self.attn_drop if self.training else 0.0
|
112 |
+
if using_flash:
|
113 |
+
oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)
|
114 |
+
elif self.using_xform:
|
115 |
+
oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C)
|
116 |
+
else:
|
117 |
+
oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C)
|
118 |
+
|
119 |
+
return self.proj_drop(self.proj(oup))
|
120 |
+
# attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb()) # BHLc @ BHcL => BHLL
|
121 |
+
# attn = self.attn_drop(attn.softmax(dim=-1))
|
122 |
+
# oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1) # BHLL @ BHLc = BHLc => BLHc => BLC
|
123 |
+
|
124 |
+
def extra_repr(self) -> str:
|
125 |
+
return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}'
|
126 |
+
|
127 |
+
|
128 |
+
class AdaLNSelfAttn(nn.Module):
|
129 |
+
def __init__(
|
130 |
+
self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer,
|
131 |
+
num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False,
|
132 |
+
flash_if_available=False, fused_if_available=True,
|
133 |
+
):
|
134 |
+
super(AdaLNSelfAttn, self).__init__()
|
135 |
+
self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
|
136 |
+
self.C, self.D = embed_dim, cond_dim
|
137 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
138 |
+
self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available)
|
139 |
+
self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available)
|
140 |
+
|
141 |
+
self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
|
142 |
+
self.shared_aln = shared_aln
|
143 |
+
if self.shared_aln:
|
144 |
+
self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
|
145 |
+
else:
|
146 |
+
lin = nn.Linear(cond_dim, 6*embed_dim)
|
147 |
+
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)
|
148 |
+
|
149 |
+
self.fused_add_norm_fn = None
|
150 |
+
|
151 |
+
# NOTE: attn_bias is None during inference because kv cache is enabled
|
152 |
+
def forward(self, x, cond_BD, attn_bias): # C: embed_dim, D: cond_dim
|
153 |
+
if self.shared_aln:
|
154 |
+
gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
|
155 |
+
else:
|
156 |
+
gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
|
157 |
+
x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1))
|
158 |
+
x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed when FusedMLP is used
|
159 |
+
return x
|
160 |
+
|
161 |
+
def extra_repr(self) -> str:
|
162 |
+
return f'shared_aln={self.shared_aln}'
|
163 |
+
|
164 |
+
|
165 |
+
class AdaLNBeforeHead(nn.Module):
|
166 |
+
def __init__(self, C, D, norm_layer): # C: embed_dim, D: cond_dim
|
167 |
+
super().__init__()
|
168 |
+
self.C, self.D = C, D
|
169 |
+
self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
|
170 |
+
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C))
|
171 |
+
|
172 |
+
def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
|
173 |
+
scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
|
174 |
+
return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
|
VAR/models/helpers.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
|
6 |
+
def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l)
|
7 |
+
B, l, V = logits_BlV.shape
|
8 |
+
if top_k > 0:
|
9 |
+
idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
|
10 |
+
logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
|
11 |
+
if top_p > 0:
|
12 |
+
sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
|
13 |
+
sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
|
14 |
+
sorted_idx_to_remove[..., -1:] = False
|
15 |
+
logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf)
|
16 |
+
# sample (have to squeeze cuz torch.multinomial can only be used for 2D tensor)
|
17 |
+
replacement = num_samples >= 0
|
18 |
+
num_samples = abs(num_samples)
|
19 |
+
return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
|
20 |
+
|
21 |
+
|
22 |
+
def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, rng: torch.Generator = None) -> torch.Tensor:
|
23 |
+
if rng is None:
|
24 |
+
return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)
|
25 |
+
|
26 |
+
gumbels = (-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_(generator=rng).log())
|
27 |
+
gumbels = (logits + gumbels) / tau
|
28 |
+
y_soft = gumbels.softmax(dim)
|
29 |
+
|
30 |
+
if hard:
|
31 |
+
index = y_soft.max(dim, keepdim=True)[1]
|
32 |
+
y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
|
33 |
+
ret = y_hard - y_soft.detach() + y_soft
|
34 |
+
else:
|
35 |
+
ret = y_soft
|
36 |
+
return ret
|
37 |
+
|
38 |
+
|
39 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): # taken from timm
|
40 |
+
if drop_prob == 0. or not training: return x
|
41 |
+
keep_prob = 1 - drop_prob
|
42 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
43 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
44 |
+
if keep_prob > 0.0 and scale_by_keep:
|
45 |
+
random_tensor.div_(keep_prob)
|
46 |
+
return x * random_tensor
|
47 |
+
|
48 |
+
|
49 |
+
class DropPath(nn.Module): # taken from timm
|
50 |
+
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
51 |
+
super(DropPath, self).__init__()
|
52 |
+
self.drop_prob = drop_prob
|
53 |
+
self.scale_by_keep = scale_by_keep
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
57 |
+
|
58 |
+
def extra_repr(self):
|
59 |
+
return f'(drop_prob=...)'
|
VAR/models/quant.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Sequence, Tuple, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch import distributed as tdist, nn as nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
import dist
|
9 |
+
|
10 |
+
# this file only provides the VectorQuantizer2 used in VQVAE
|
11 |
+
__all__ = ['VectorQuantizer2', ]
|
12 |
+
|
13 |
+
|
14 |
+
class VectorQuantizer2(nn.Module):
|
15 |
+
# VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25
|
16 |
+
def __init__(
|
17 |
+
self, vocab_size, Cvae, using_znorm, beta: float = 0.25,
|
18 |
+
default_qresi_counts=0, v_patch_nums=None, quant_resi=0.5, share_quant_resi=4, # share_quant_resi: args.qsr
|
19 |
+
):
|
20 |
+
super().__init__()
|
21 |
+
self.vocab_size: int = vocab_size
|
22 |
+
self.Cvae: int = Cvae
|
23 |
+
self.using_znorm: bool = using_znorm
|
24 |
+
self.v_patch_nums: Tuple[int] = v_patch_nums
|
25 |
+
|
26 |
+
self.quant_resi_ratio = quant_resi
|
27 |
+
if share_quant_resi == 0: # non-shared: \phi_{1 to K} for K scales
|
28 |
+
self.quant_resi = PhiNonShared(
|
29 |
+
[(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in
|
30 |
+
range(default_qresi_counts or len(self.v_patch_nums))])
|
31 |
+
elif share_quant_resi == 1: # fully shared: only a single \phi for K scales
|
32 |
+
self.quant_resi = PhiShared(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
|
33 |
+
else: # partially shared: \phi_{1 to share_quant_resi} for K scales
|
34 |
+
self.quant_resi = PhiPartiallyShared(nn.ModuleList(
|
35 |
+
[(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in
|
36 |
+
range(share_quant_resi)]))
|
37 |
+
|
38 |
+
self.register_buffer('ema_vocab_hit_SV', torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0))
|
39 |
+
self.record_hit = 0
|
40 |
+
|
41 |
+
self.beta: float = beta
|
42 |
+
self.embedding = nn.Embedding(self.vocab_size, self.Cvae)
|
43 |
+
|
44 |
+
# only used for progressive training of VAR (not supported yet, will be tested and supported in the future)
|
45 |
+
self.prog_si = -1 # progressive training: not supported yet, prog_si always -1
|
46 |
+
|
47 |
+
def eini(self, eini):
|
48 |
+
if eini > 0:
|
49 |
+
nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
|
50 |
+
elif eini < 0:
|
51 |
+
self.embedding.weight.data.uniform_(-abs(eini) / self.vocab_size, abs(eini) / self.vocab_size)
|
52 |
+
|
53 |
+
def extra_repr(self) -> str:
|
54 |
+
return f'{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}'
|
55 |
+
|
56 |
+
# ===================== `forward` is only used in VAE training =====================
|
57 |
+
def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, List[float], torch.Tensor]:
|
58 |
+
dtype = f_BChw.dtype
|
59 |
+
if dtype != torch.float32: f_BChw = f_BChw.float()
|
60 |
+
B, C, H, W = f_BChw.shape
|
61 |
+
f_no_grad = f_BChw.detach()
|
62 |
+
|
63 |
+
f_rest = f_no_grad.clone()
|
64 |
+
f_hat = torch.zeros_like(f_rest)
|
65 |
+
|
66 |
+
with torch.cuda.amp.autocast(enabled=False):
|
67 |
+
mean_vq_loss: torch.Tensor = 0.0
|
68 |
+
vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=f_BChw.device)
|
69 |
+
SN = len(self.v_patch_nums)
|
70 |
+
for si, pn in enumerate(self.v_patch_nums): # from small to large
|
71 |
+
# find the nearest embedding
|
72 |
+
if self.using_znorm:
|
73 |
+
rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='bilinear').permute(0, 2, 3, 1).reshape(-1,
|
74 |
+
C) if (
|
75 |
+
si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
|
76 |
+
rest_NC = F.normalize(rest_NC, dim=-1)
|
77 |
+
idx_N = torch.argmax(rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
|
78 |
+
else:
|
79 |
+
rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='bilinear').permute(0, 2, 3, 1).reshape(-1,
|
80 |
+
C) if (
|
81 |
+
si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
|
82 |
+
d_no_grad = torch.sum(rest_NC.square(), dim=1, keepdim=True) + torch.sum(
|
83 |
+
self.embedding.weight.data.square(), dim=1, keepdim=False)
|
84 |
+
d_no_grad.addmm_(rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
|
85 |
+
idx_N = torch.argmin(d_no_grad, dim=1)
|
86 |
+
|
87 |
+
hit_V = idx_N.bincount(minlength=self.vocab_size).float()
|
88 |
+
if self.training:
|
89 |
+
if dist.initialized(): handler = tdist.all_reduce(hit_V, async_op=True)
|
90 |
+
|
91 |
+
# calc loss
|
92 |
+
idx_Bhw = idx_N.view(B, pn, pn)
|
93 |
+
h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W),
|
94 |
+
mode='bilinear').contiguous() if (si != SN - 1) else self.embedding(
|
95 |
+
idx_Bhw).permute(0, 3, 1, 2).contiguous()
|
96 |
+
h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
|
97 |
+
f_hat = f_hat + h_BChw
|
98 |
+
f_rest -= h_BChw
|
99 |
+
|
100 |
+
if self.training and dist.initialized():
|
101 |
+
handler.wait()
|
102 |
+
if self.record_hit == 0:
|
103 |
+
self.ema_vocab_hit_SV[si].copy_(hit_V)
|
104 |
+
elif self.record_hit < 100:
|
105 |
+
self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
|
106 |
+
else:
|
107 |
+
self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
|
108 |
+
self.record_hit += 1
|
109 |
+
vocab_hit_V.add_(hit_V)
|
110 |
+
mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)
|
111 |
+
|
112 |
+
mean_vq_loss *= 1. / SN
|
113 |
+
f_hat = (f_hat.data - f_no_grad).add_(f_BChw)
|
114 |
+
|
115 |
+
margin = tdist.get_world_size() * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08
|
116 |
+
# margin = pn*pn / 100
|
117 |
+
if ret_usages:
|
118 |
+
usages = [(self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 for si, pn in
|
119 |
+
enumerate(self.v_patch_nums)]
|
120 |
+
else:
|
121 |
+
usages = None
|
122 |
+
return f_hat, usages, mean_vq_loss
|
123 |
+
|
124 |
+
# ===================== `forward` is only used in VAE training =====================
|
125 |
+
|
126 |
+
def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False) -> Union[
|
127 |
+
List[torch.Tensor], torch.Tensor]:
|
128 |
+
ls_f_hat_BChw = []
|
129 |
+
B = ms_h_BChw[0].shape[0]
|
130 |
+
H = W = self.v_patch_nums[-1]
|
131 |
+
SN = len(self.v_patch_nums)
|
132 |
+
if all_to_max_scale:
|
133 |
+
f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
|
134 |
+
for si, pn in enumerate(self.v_patch_nums): # from small to large
|
135 |
+
h_BChw = ms_h_BChw[si]
|
136 |
+
if si < len(self.v_patch_nums) - 1:
|
137 |
+
h_BChw = F.interpolate(h_BChw, size=(H, W), mode='bilinear')
|
138 |
+
h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
|
139 |
+
f_hat.add_(h_BChw)
|
140 |
+
if last_one:
|
141 |
+
ls_f_hat_BChw = f_hat
|
142 |
+
else:
|
143 |
+
ls_f_hat_BChw.append(f_hat.clone())
|
144 |
+
else:
|
145 |
+
# WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)
|
146 |
+
# WARNING: this should only be used for experimental purpose
|
147 |
+
f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, self.v_patch_nums[0], self.v_patch_nums[0],
|
148 |
+
dtype=torch.float32)
|
149 |
+
for si, pn in enumerate(self.v_patch_nums): # from small to large
|
150 |
+
f_hat = F.interpolate(f_hat, size=(pn, pn), mode='bilinear')
|
151 |
+
h_BChw = self.quant_resi[si / (SN - 1)](ms_h_BChw[si])
|
152 |
+
f_hat.add_(h_BChw)
|
153 |
+
if last_one:
|
154 |
+
ls_f_hat_BChw = f_hat
|
155 |
+
else:
|
156 |
+
ls_f_hat_BChw.append(f_hat)
|
157 |
+
|
158 |
+
return ls_f_hat_BChw
|
159 |
+
|
160 |
+
def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool,
|
161 |
+
v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[
|
162 |
+
Union[torch.Tensor, torch.LongTensor]]: # z_BChw is the feature from inp_img_no_grad
|
163 |
+
B, C, H, W = f_BChw.shape
|
164 |
+
f_no_grad = f_BChw.detach()
|
165 |
+
f_rest = f_no_grad.clone()
|
166 |
+
f_hat = torch.zeros_like(f_rest)
|
167 |
+
|
168 |
+
f_hat_or_idx_Bl: List[torch.Tensor] = []
|
169 |
+
|
170 |
+
patch_hws = [(pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) for pn in
|
171 |
+
(v_patch_nums or self.v_patch_nums)] # from small to large
|
172 |
+
assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})'
|
173 |
+
|
174 |
+
SN = len(patch_hws)
|
175 |
+
for si, (ph, pw) in enumerate(patch_hws): # from small to large
|
176 |
+
if 0 <= self.prog_si < si: break # progressive training: not supported yet, prog_si always -1
|
177 |
+
# find the nearest embedding
|
178 |
+
z_NC = F.interpolate(f_rest, size=(ph, pw), mode='bilinear').permute(0, 2, 3, 1).reshape(-1, C) if (
|
179 |
+
si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
|
180 |
+
if self.using_znorm:
|
181 |
+
z_NC = F.normalize(z_NC, dim=-1)
|
182 |
+
idx_N = torch.argmax(z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
|
183 |
+
else:
|
184 |
+
d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(
|
185 |
+
self.embedding.weight.data.square(), dim=1, keepdim=False)
|
186 |
+
d_no_grad.addmm_(z_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
|
187 |
+
idx_N = torch.argmin(d_no_grad, dim=1)
|
188 |
+
|
189 |
+
idx_Bhw = idx_N.view(B, ph, pw)
|
190 |
+
h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W),
|
191 |
+
mode='bilinear').contiguous() if (si != SN - 1) else self.embedding(idx_Bhw).permute(
|
192 |
+
0, 3, 1, 2).contiguous()
|
193 |
+
h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
|
194 |
+
f_hat.add_(h_BChw)
|
195 |
+
f_rest.sub_(h_BChw)
|
196 |
+
f_hat_or_idx_Bl.append(f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw))
|
197 |
+
|
198 |
+
return f_hat_or_idx_Bl
|
199 |
+
|
200 |
+
# ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
|
201 |
+
def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
|
202 |
+
next_scales = []
|
203 |
+
B = gt_ms_idx_Bl[0].shape[0]
|
204 |
+
C = self.Cvae
|
205 |
+
H = W = self.v_patch_nums[-1]
|
206 |
+
SN = len(self.v_patch_nums)
|
207 |
+
|
208 |
+
f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
|
209 |
+
pn_next: int = self.v_patch_nums[0]
|
210 |
+
for si in range(SN - 1):
|
211 |
+
if self.prog_si == 0 or (
|
212 |
+
0 <= self.prog_si - 1 < si): break # progressive training: not supported yet, prog_si always -1
|
213 |
+
h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next),
|
214 |
+
size=(H, W), mode='bilinear')
|
215 |
+
f_hat.add_(self.quant_resi[si / (SN - 1)](h_BChw))
|
216 |
+
pn_next = self.v_patch_nums[si + 1]
|
217 |
+
next_scales.append(
|
218 |
+
F.interpolate(f_hat, size=(pn_next, pn_next), mode='bilinear').view(B, C, -1).transpose(1, 2))
|
219 |
+
return torch.cat(next_scales, dim=1) if len(next_scales) else None # cat BlCs to BLC, this should be float32
|
220 |
+
|
221 |
+
# ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
|
222 |
+
def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor) -> Tuple[
|
223 |
+
Optional[torch.Tensor], torch.Tensor]: # only used in VAR inference
|
224 |
+
HW = self.v_patch_nums[-1]
|
225 |
+
if si != SN - 1:
|
226 |
+
h = self.quant_resi[si / (SN - 1)](
|
227 |
+
F.interpolate(h_BChw, size=(HW, HW), mode='bilinear')) # conv after upsample
|
228 |
+
f_hat.add_(h)
|
229 |
+
return f_hat, F.interpolate(f_hat, size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]),
|
230 |
+
mode='bilinear')
|
231 |
+
else:
|
232 |
+
h = self.quant_resi[si / (SN - 1)](h_BChw)
|
233 |
+
f_hat.add_(h)
|
234 |
+
return f_hat, f_hat
|
235 |
+
|
236 |
+
|
237 |
+
class Phi(nn.Conv2d):
|
238 |
+
def __init__(self, embed_dim, quant_resi):
|
239 |
+
ks = 3
|
240 |
+
super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
|
241 |
+
self.resi_ratio = abs(quant_resi)
|
242 |
+
|
243 |
+
def forward(self, h_BChw):
|
244 |
+
return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)
|
245 |
+
|
246 |
+
|
247 |
+
class PhiShared(nn.Module):
|
248 |
+
def __init__(self, qresi: Phi):
|
249 |
+
super().__init__()
|
250 |
+
self.qresi: Phi = qresi
|
251 |
+
|
252 |
+
def __getitem__(self, _) -> Phi:
|
253 |
+
return self.qresi
|
254 |
+
|
255 |
+
|
256 |
+
class PhiPartiallyShared(nn.Module):
|
257 |
+
def __init__(self, qresi_ls: nn.ModuleList):
|
258 |
+
super().__init__()
|
259 |
+
self.qresi_ls = qresi_ls
|
260 |
+
K = len(qresi_ls)
|
261 |
+
self.ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) if K == 4 else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
|
262 |
+
|
263 |
+
def __getitem__(self, at_from_0_to_1: float) -> Phi:
|
264 |
+
return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]
|
265 |
+
|
266 |
+
def extra_repr(self) -> str:
|
267 |
+
return f'ticks={self.ticks}'
|
268 |
+
|
269 |
+
|
270 |
+
class PhiNonShared(nn.ModuleList):
|
271 |
+
def __init__(self, qresi: List):
|
272 |
+
super().__init__(qresi)
|
273 |
+
# self.qresi = qresi
|
274 |
+
K = len(qresi)
|
275 |
+
self.ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) if K == 4 else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
|
276 |
+
|
277 |
+
def __getitem__(self, at_from_0_to_1: float) -> Phi:
|
278 |
+
return super().__getitem__(np.argmin(np.abs(self.ticks - at_from_0_to_1)).item())
|
279 |
+
|
280 |
+
def extra_repr(self) -> str:
|
281 |
+
return f'ticks={self.ticks}'
|
VAR/models/var.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from functools import partial
|
3 |
+
from typing import Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import tqdm
|
8 |
+
from huggingface_hub import PyTorchModelHubMixin
|
9 |
+
|
10 |
+
import dist
|
11 |
+
from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn
|
12 |
+
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
|
13 |
+
from models.vqvae import VQVAE, VectorQuantizer2
|
14 |
+
import lovely_tensors as lt
|
15 |
+
lt.monkey_patch()
|
16 |
+
|
17 |
+
class SharedAdaLin(nn.Linear):
|
18 |
+
def forward(self, cond_BD):
|
19 |
+
C = self.weight.shape[0] // 6
|
20 |
+
return super().forward(cond_BD).view(-1, 1, 6, C) # B16C
|
21 |
+
|
22 |
+
|
23 |
+
class VAR(nn.Module):
|
24 |
+
def __init__(
|
25 |
+
self, vae_local: VQVAE,
|
26 |
+
num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
27 |
+
norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
|
28 |
+
attn_l2_norm=False,
|
29 |
+
patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
|
30 |
+
flash_if_available=True, fused_if_available=True,
|
31 |
+
):
|
32 |
+
super().__init__()
|
33 |
+
# 0. hyperparameters
|
34 |
+
assert embed_dim % num_heads == 0
|
35 |
+
self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size
|
36 |
+
self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads
|
37 |
+
|
38 |
+
self.cond_drop_rate = cond_drop_rate
|
39 |
+
self.prog_si = -1 # progressive training
|
40 |
+
|
41 |
+
self.patch_nums: Tuple[int] = patch_nums
|
42 |
+
self.L = sum(pn ** 2 for pn in self.patch_nums)
|
43 |
+
self.first_l = self.patch_nums[0] ** 2
|
44 |
+
self.begin_ends = []
|
45 |
+
cur = 0
|
46 |
+
for i, pn in enumerate(self.patch_nums):
|
47 |
+
self.begin_ends.append((cur, cur+pn ** 2))
|
48 |
+
cur += pn ** 2
|
49 |
+
|
50 |
+
self.num_stages_minus_1 = len(self.patch_nums) - 1
|
51 |
+
self.rng = torch.Generator(device="mps")
|
52 |
+
|
53 |
+
# 1. input (word) embedding
|
54 |
+
quant: VectorQuantizer2 = vae_local.quantize
|
55 |
+
self.vae_proxy: Tuple[VQVAE] = (vae_local,)
|
56 |
+
self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,)
|
57 |
+
self.word_embed = nn.Linear(self.Cvae, self.C)
|
58 |
+
|
59 |
+
# 2. class embedding
|
60 |
+
init_std = math.sqrt(1 / self.C / 3)
|
61 |
+
self.num_classes = num_classes
|
62 |
+
self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32, device=dist.get_device())
|
63 |
+
self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
|
64 |
+
nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
|
65 |
+
self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
|
66 |
+
nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
|
67 |
+
|
68 |
+
# 3. absolute position embedding
|
69 |
+
pos_1LC = []
|
70 |
+
for i, pn in enumerate(self.patch_nums):
|
71 |
+
pe = torch.empty(1, pn*pn, self.C)
|
72 |
+
nn.init.trunc_normal_(pe, mean=0, std=init_std)
|
73 |
+
pos_1LC.append(pe)
|
74 |
+
pos_1LC = torch.cat(pos_1LC, dim=1) # 1, L, C
|
75 |
+
assert tuple(pos_1LC.shape) == (1, self.L, self.C)
|
76 |
+
self.pos_1LC = nn.Parameter(pos_1LC)
|
77 |
+
# level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid)
|
78 |
+
self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
|
79 |
+
nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
|
80 |
+
|
81 |
+
# 4. backbone blocks
|
82 |
+
self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()
|
83 |
+
|
84 |
+
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
|
85 |
+
self.drop_path_rate = drop_path_rate
|
86 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule (linearly increasing)
|
87 |
+
self.blocks = nn.ModuleList([
|
88 |
+
AdaLNSelfAttn(
|
89 |
+
cond_dim=self.D, shared_aln=shared_aln,
|
90 |
+
block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio,
|
91 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx], last_drop_p=0 if block_idx == 0 else dpr[block_idx-1],
|
92 |
+
attn_l2_norm=attn_l2_norm,
|
93 |
+
flash_if_available=flash_if_available, fused_if_available=fused_if_available,
|
94 |
+
)
|
95 |
+
for block_idx in range(depth)
|
96 |
+
])
|
97 |
+
|
98 |
+
fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
|
99 |
+
self.using_fused_add_norm_fn = any(fused_add_norm_fns)
|
100 |
+
print(
|
101 |
+
f'\n[constructor] ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n'
|
102 |
+
f' [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n'
|
103 |
+
f' [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
|
104 |
+
end='\n\n', flush=True
|
105 |
+
)
|
106 |
+
|
107 |
+
# 5. attention mask used in training (for masking out the future)
|
108 |
+
# it won't be used in inference, since kv cache is enabled
|
109 |
+
d: torch.Tensor = torch.cat([torch.full((pn*pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L, 1)
|
110 |
+
dT = d.transpose(1, 2) # dT: 11L
|
111 |
+
lvl_1L = dT[:, 0].contiguous()
|
112 |
+
self.register_buffer('lvl_1L', lvl_1L)
|
113 |
+
attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L)
|
114 |
+
self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())
|
115 |
+
|
116 |
+
# 6. classifier head
|
117 |
+
self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
|
118 |
+
self.head = nn.Linear(self.C, self.V)
|
119 |
+
|
120 |
+
def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], cond_BD: Optional[torch.Tensor]):
|
121 |
+
if not isinstance(h_or_h_and_residual, torch.Tensor):
|
122 |
+
h, resi = h_or_h_and_residual # fused_add_norm must be used
|
123 |
+
h = resi + self.blocks[-1].drop_path(h)
|
124 |
+
else: # fused_add_norm is not used
|
125 |
+
h = h_or_h_and_residual
|
126 |
+
return self.head(self.head_nm(h.float(), cond_BD).float()).float()
|
127 |
+
|
128 |
+
@torch.no_grad()
|
129 |
+
def autoregressive_infer_cfg(
|
130 |
+
self, B: int, label_B: Optional[Union[int, torch.LongTensor]],cond_delta: torch.Tensor =None,
|
131 |
+
g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0,beta=0,alpha = 1,
|
132 |
+
more_smooth=False,
|
133 |
+
|
134 |
+
) -> torch.Tensor: # returns reconstructed image (B, 3, H, W) in [0, 1]
|
135 |
+
"""
|
136 |
+
only used for inference, on autoregressive mode
|
137 |
+
:param B: batch size
|
138 |
+
:param label_B: imagenet label; if None, randomly sampled
|
139 |
+
:param g_seed: random seed
|
140 |
+
:param cfg: classifier-free guidance ratio
|
141 |
+
:param top_k: top-k sampling
|
142 |
+
:param top_p: top-p sampling
|
143 |
+
:param more_smooth: smoothing the pred using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
|
144 |
+
:return: if returns_vemb: list of embedding h_BChw := vae_embed(idx_Bl), else: list of idx_Bl
|
145 |
+
"""
|
146 |
+
if g_seed is None: rng = None
|
147 |
+
else: self.rng.manual_seed(g_seed); rng = self.rng
|
148 |
+
|
149 |
+
if label_B is None:
|
150 |
+
label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B)
|
151 |
+
elif isinstance(label_B, int):
|
152 |
+
label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=self.lvl_1L.device)
|
153 |
+
|
154 |
+
if alpha == 1:
|
155 |
+
sos = cond_BD = self.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=self.num_classes))))
|
156 |
+
else:
|
157 |
+
sos = cond_BD = self.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=self.num_classes))))
|
158 |
+
sos[0] = cond_BD[0] = cond_BD[0]*alpha
|
159 |
+
if cond_delta is not None and beta != 0:
|
160 |
+
cond_BD[0] += cond_delta[0] * beta
|
161 |
+
sos[0] += sos[0] * beta
|
162 |
+
|
163 |
+
lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC
|
164 |
+
next_token_map = sos.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1) + lvl_pos[:, :self.first_l]
|
165 |
+
|
166 |
+
cur_L = 0
|
167 |
+
f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])
|
168 |
+
for b in self.blocks: b.attn.kv_caching(True)
|
169 |
+
for si, pn in enumerate(self.patch_nums): # si: i-th segment
|
170 |
+
ratio = si / self.num_stages_minus_1
|
171 |
+
# last_L = cur_L
|
172 |
+
cur_L += pn*pn
|
173 |
+
# assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'
|
174 |
+
cond_BD_or_gss = self.shared_ada_lin(cond_BD)
|
175 |
+
x = next_token_map
|
176 |
+
AdaLNSelfAttn.forward
|
177 |
+
for b in self.blocks:
|
178 |
+
x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)
|
179 |
+
logits_BlV = self.get_logits(x, cond_BD)
|
180 |
+
t = cfg * ratio
|
181 |
+
logits_BlV = (1+t) * logits_BlV[:B] - t * logits_BlV[B:]
|
182 |
+
|
183 |
+
idx_Bl = sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0]
|
184 |
+
if not more_smooth: # this is the default case
|
185 |
+
h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl) # B, l, Cvae
|
186 |
+
else: # not used when evaluating FID/IS/Precision/Recall
|
187 |
+
gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git
|
188 |
+
h_BChw = gumbel_softmax_with_rng(logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng) @ self.vae_quant_proxy[0].embedding.weight.unsqueeze(0)
|
189 |
+
|
190 |
+
h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn)
|
191 |
+
f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(si, len(self.patch_nums), f_hat, h_BChw)
|
192 |
+
if si != self.num_stages_minus_1: # prepare for next stage
|
193 |
+
next_token_map = next_token_map.view(B, self.Cvae, -1).transpose(1, 2)
|
194 |
+
next_token_map = self.word_embed(next_token_map) + lvl_pos[:, cur_L:cur_L + self.patch_nums[si+1] ** 2]
|
195 |
+
next_token_map = next_token_map.repeat(2, 1, 1) # double the batch sizes due to CFG
|
196 |
+
|
197 |
+
for b in self.blocks: b.attn.kv_caching(False)
|
198 |
+
return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5) # de-normalize, from [-1, 1] to [0, 1]
|
199 |
+
|
200 |
+
def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor,cond_delta: torch.Tensor = None,beta=0.05,alpha = 1) -> torch.Tensor: # returns logits_BLV
|
201 |
+
"""
|
202 |
+
:param label_B: label_B
|
203 |
+
:param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
|
204 |
+
:return: logits BLV, V is vocab_size
|
205 |
+
"""
|
206 |
+
bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)
|
207 |
+
B = x_BLCv_wo_first_l.shape[0]
|
208 |
+
with torch.cuda.amp.autocast(enabled=False):
|
209 |
+
label_B = torch.where(torch.rand(B, device=label_B.device) < self.cond_drop_rate, self.num_classes, label_B)
|
210 |
+
if cond_delta is not None:
|
211 |
+
sos = cond_BD = alpha * self.class_emb(label_B) + beta * cond_delta
|
212 |
+
else:
|
213 |
+
sos = cond_BD = self.class_emb(label_B)
|
214 |
+
sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)
|
215 |
+
|
216 |
+
if self.prog_si == 0: x_BLC = sos
|
217 |
+
else: x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1)
|
218 |
+
x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed] # lvl: BLC; pos: 1LC
|
219 |
+
|
220 |
+
attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
|
221 |
+
cond_BD_or_gss = self.shared_ada_lin(cond_BD)
|
222 |
+
|
223 |
+
# hack: get the dtype if mixed precision is used
|
224 |
+
temp = x_BLC.new_ones(8, 8)
|
225 |
+
main_type = torch.matmul(temp, temp).dtype
|
226 |
+
|
227 |
+
x_BLC = x_BLC.to(dtype=main_type)
|
228 |
+
cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
|
229 |
+
attn_bias = attn_bias.to(dtype=main_type)
|
230 |
+
|
231 |
+
AdaLNSelfAttn.forward
|
232 |
+
for i, b in enumerate(self.blocks):
|
233 |
+
x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)
|
234 |
+
x_BLC = self.get_logits(x_BLC.float(), cond_BD)
|
235 |
+
|
236 |
+
if self.prog_si == 0:
|
237 |
+
if isinstance(self.word_embed, nn.Linear):
|
238 |
+
x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0
|
239 |
+
else:
|
240 |
+
s = 0
|
241 |
+
for p in self.word_embed.parameters():
|
242 |
+
if p.requires_grad:
|
243 |
+
s += p.view(-1)[0] * 0
|
244 |
+
x_BLC[0, 0, 0] += s
|
245 |
+
return x_BLC # logits BLV, V is vocab_size
|
246 |
+
|
247 |
+
def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02):
|
248 |
+
if init_std < 0: init_std = (1 / self.C / 3) ** 0.5 # init_std < 0: automated
|
249 |
+
|
250 |
+
print(f'[init_weights] {type(self).__name__} with {init_std=:g}')
|
251 |
+
for m in self.modules():
|
252 |
+
with_weight = hasattr(m, 'weight') and m.weight is not None
|
253 |
+
with_bias = hasattr(m, 'bias') and m.bias is not None
|
254 |
+
if isinstance(m, nn.Linear):
|
255 |
+
nn.init.trunc_normal_(m.weight.data, std=init_std)
|
256 |
+
if with_bias: m.bias.data.zero_()
|
257 |
+
elif isinstance(m, nn.Embedding):
|
258 |
+
nn.init.trunc_normal_(m.weight.data, std=init_std)
|
259 |
+
if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
|
260 |
+
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
|
261 |
+
if with_weight: m.weight.data.fill_(1.)
|
262 |
+
if with_bias: m.bias.data.zero_()
|
263 |
+
# conv: VAR has no conv, only VQVAE has conv
|
264 |
+
elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
265 |
+
if conv_std_or_gain > 0: nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
|
266 |
+
else: nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
|
267 |
+
if with_bias: m.bias.data.zero_()
|
268 |
+
|
269 |
+
if init_head >= 0:
|
270 |
+
if isinstance(self.head, nn.Linear):
|
271 |
+
self.head.weight.data.mul_(init_head)
|
272 |
+
self.head.bias.data.zero_()
|
273 |
+
elif isinstance(self.head, nn.Sequential):
|
274 |
+
self.head[-1].weight.data.mul_(init_head)
|
275 |
+
self.head[-1].bias.data.zero_()
|
276 |
+
|
277 |
+
if isinstance(self.head_nm, AdaLNBeforeHead):
|
278 |
+
self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
|
279 |
+
if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
|
280 |
+
self.head_nm.ada_lin[-1].bias.data.zero_()
|
281 |
+
|
282 |
+
depth = len(self.blocks)
|
283 |
+
for block_idx, sab in enumerate(self.blocks):
|
284 |
+
sab: AdaLNSelfAttn
|
285 |
+
sab.attn.proj.weight.data.div_(math.sqrt(2 * depth))
|
286 |
+
sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
|
287 |
+
if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None:
|
288 |
+
nn.init.ones_(sab.ffn.fcg.bias)
|
289 |
+
nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
|
290 |
+
if hasattr(sab, 'ada_lin'):
|
291 |
+
sab.ada_lin[-1].weight.data[2*self.C:].mul_(init_adaln)
|
292 |
+
sab.ada_lin[-1].weight.data[:2*self.C].mul_(init_adaln_gamma)
|
293 |
+
if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None:
|
294 |
+
sab.ada_lin[-1].bias.data.zero_()
|
295 |
+
elif hasattr(sab, 'ada_gss'):
|
296 |
+
sab.ada_gss.data[:, :, 2:].mul_(init_adaln)
|
297 |
+
sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)
|
298 |
+
|
299 |
+
def extra_repr(self):
|
300 |
+
return f'drop_path_rate={self.drop_path_rate:g}'
|
301 |
+
|
302 |
+
|
303 |
+
class VARHF(VAR, PyTorchModelHubMixin):
|
304 |
+
# repo_url="https://github.com/FoundationVision/VAR",
|
305 |
+
# tags=["image-generation"]):
|
306 |
+
def __init__(
|
307 |
+
self,
|
308 |
+
vae_kwargs,
|
309 |
+
num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
310 |
+
norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
|
311 |
+
attn_l2_norm=False,
|
312 |
+
patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
|
313 |
+
flash_if_available=True, fused_if_available=True,
|
314 |
+
):
|
315 |
+
vae_local = VQVAE(**vae_kwargs)
|
316 |
+
super().__init__(
|
317 |
+
vae_local=vae_local,
|
318 |
+
num_classes=num_classes, depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
|
319 |
+
norm_eps=norm_eps, shared_aln=shared_aln, cond_drop_rate=cond_drop_rate,
|
320 |
+
attn_l2_norm=attn_l2_norm,
|
321 |
+
patch_nums=patch_nums,
|
322 |
+
flash_if_available=flash_if_available, fused_if_available=fused_if_available,
|
323 |
+
)
|
VAR/models/vqvae.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
References:
|
3 |
+
- VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110
|
4 |
+
- GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213
|
5 |
+
- VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14
|
6 |
+
"""
|
7 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from .basic_vae import Decoder, Encoder
|
13 |
+
from .quant import VectorQuantizer2
|
14 |
+
|
15 |
+
|
16 |
+
class VQVAE(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self, vocab_size=4096, z_channels=32, ch=128, dropout=0.0,
|
19 |
+
beta=0.25, # commitment loss weight
|
20 |
+
using_znorm=False, # whether to normalize when computing the nearest neighbors
|
21 |
+
quant_conv_ks=3, # quant conv kernel size
|
22 |
+
quant_resi=0.5, # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x
|
23 |
+
share_quant_resi=4, # use 4 \phi layers for K scales: partially-shared \phi
|
24 |
+
default_qresi_counts=0, # if is 0: automatically set to len(v_patch_nums)
|
25 |
+
v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k]
|
26 |
+
test_mode=True,
|
27 |
+
):
|
28 |
+
super().__init__()
|
29 |
+
self.test_mode = test_mode
|
30 |
+
self.V, self.Cvae = vocab_size, z_channels
|
31 |
+
# ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml
|
32 |
+
ddconfig = dict(
|
33 |
+
dropout=dropout, ch=ch, z_channels=z_channels,
|
34 |
+
in_channels=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2, # from vq-f16/config.yaml above
|
35 |
+
using_sa=True, using_mid_sa=True, # from vq-f16/config.yaml above
|
36 |
+
# resamp_with_conv=True, # always True, removed.
|
37 |
+
)
|
38 |
+
ddconfig.pop('double_z', None) # only KL-VAE should use double_z=True
|
39 |
+
self.encoder = Encoder(double_z=False, **ddconfig)
|
40 |
+
self.decoder = Decoder(**ddconfig)
|
41 |
+
|
42 |
+
self.vocab_size = vocab_size
|
43 |
+
self.downsample = 2 ** (len(ddconfig['ch_mult'])-1)
|
44 |
+
self.quantize: VectorQuantizer2 = VectorQuantizer2(
|
45 |
+
vocab_size=vocab_size, Cvae=self.Cvae, using_znorm=using_znorm, beta=beta,
|
46 |
+
default_qresi_counts=default_qresi_counts, v_patch_nums=v_patch_nums, quant_resi=quant_resi, share_quant_resi=share_quant_resi,
|
47 |
+
)
|
48 |
+
self.quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)
|
49 |
+
self.post_quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)
|
50 |
+
|
51 |
+
if self.test_mode:
|
52 |
+
self.eval()
|
53 |
+
[p.requires_grad_(False) for p in self.parameters()]
|
54 |
+
|
55 |
+
# ===================== `forward` is only used in VAE training =====================
|
56 |
+
def forward(self, inp, ret_usages=False): # -> rec_B3HW, idx_N, loss
|
57 |
+
VectorQuantizer2.forward
|
58 |
+
f_hat, usages, vq_loss = self.quantize(self.quant_conv(self.encoder(inp)), ret_usages=ret_usages)
|
59 |
+
return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss
|
60 |
+
# ===================== `forward` is only used in VAE training =====================
|
61 |
+
|
62 |
+
def fhat_to_img(self, f_hat: torch.Tensor):
|
63 |
+
return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
|
64 |
+
|
65 |
+
def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[torch.LongTensor]: # return List[Bl]
|
66 |
+
f = self.quant_conv(self.encoder(inp_img_no_grad))
|
67 |
+
return self.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=v_patch_nums)
|
68 |
+
|
69 |
+
def idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
|
70 |
+
B = ms_idx_Bl[0].shape[0]
|
71 |
+
ms_h_BChw = []
|
72 |
+
for idx_Bl in ms_idx_Bl:
|
73 |
+
l = idx_Bl.shape[1]
|
74 |
+
pn = round(l ** 0.5)
|
75 |
+
ms_h_BChw.append(self.quantize.embedding(idx_Bl).transpose(1, 2).view(B, self.Cvae, pn, pn))
|
76 |
+
return self.embed_to_img(ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one)
|
77 |
+
|
78 |
+
def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
|
79 |
+
if last_one:
|
80 |
+
return self.decoder(self.post_quant_conv(self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True))).clamp_(-1, 1)
|
81 |
+
else:
|
82 |
+
return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False)]
|
83 |
+
|
84 |
+
def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, last_one=False) -> List[torch.Tensor]:
|
85 |
+
f = self.quant_conv(self.encoder(x))
|
86 |
+
ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(f, to_fhat=True, v_patch_nums=v_patch_nums)
|
87 |
+
if last_one:
|
88 |
+
return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)
|
89 |
+
else:
|
90 |
+
return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in ls_f_hat_BChw]
|
91 |
+
|
92 |
+
def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):
|
93 |
+
if 'quantize.ema_vocab_hit_SV' in state_dict and state_dict['quantize.ema_vocab_hit_SV'].shape[0] != self.quantize.ema_vocab_hit_SV.shape[0]:
|
94 |
+
state_dict['quantize.ema_vocab_hit_SV'] = self.quantize.ema_vocab_hit_SV
|
95 |
+
return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
|
VAR/requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch~=2.1.0
|
2 |
+
|
3 |
+
Pillow
|
4 |
+
huggingface_hub
|
5 |
+
numpy
|
6 |
+
pytz
|
7 |
+
transformers
|
8 |
+
typed-argument-parser
|
VAR/train.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import sys
|
5 |
+
import time
|
6 |
+
import warnings
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
|
12 |
+
import dist
|
13 |
+
from utils import arg_util, misc
|
14 |
+
from utils.data import build_dataset
|
15 |
+
from utils.data_sampler import DistInfiniteBatchSampler, EvalDistributedSampler
|
16 |
+
from utils.misc import auto_resume
|
17 |
+
|
18 |
+
|
19 |
+
def build_everything(args: arg_util.Args):
|
20 |
+
# resume
|
21 |
+
auto_resume_info, start_ep, start_it, trainer_state, args_state = auto_resume(args, 'ar-ckpt*.pth')
|
22 |
+
# create tensorboard logger
|
23 |
+
tb_lg: misc.TensorboardLogger
|
24 |
+
with_tb_lg = dist.is_master()
|
25 |
+
if with_tb_lg:
|
26 |
+
os.makedirs(args.tb_log_dir_path, exist_ok=True)
|
27 |
+
# noinspection PyTypeChecker
|
28 |
+
tb_lg = misc.DistLogger(misc.TensorboardLogger(log_dir=args.tb_log_dir_path, filename_suffix=f'__{misc.time_str("%m%d_%H%M")}'), verbose=True)
|
29 |
+
tb_lg.flush()
|
30 |
+
else:
|
31 |
+
# noinspection PyTypeChecker
|
32 |
+
tb_lg = misc.DistLogger(None, verbose=False)
|
33 |
+
dist.barrier()
|
34 |
+
|
35 |
+
# log args
|
36 |
+
print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}')
|
37 |
+
print(f'initial args:\n{str(args)}')
|
38 |
+
|
39 |
+
# build data
|
40 |
+
if not args.local_debug:
|
41 |
+
print(f'[build PT data] ...\n')
|
42 |
+
num_classes, dataset_train, dataset_val = build_dataset(
|
43 |
+
args.data_path, final_reso=args.data_load_reso, hflip=args.hflip, mid_reso=args.mid_reso,
|
44 |
+
)
|
45 |
+
types = str((type(dataset_train).__name__, type(dataset_val).__name__))
|
46 |
+
|
47 |
+
ld_val = DataLoader(
|
48 |
+
dataset_val, num_workers=0, pin_memory=True,
|
49 |
+
batch_size=round(args.batch_size*1.5), sampler=EvalDistributedSampler(dataset_val, num_replicas=dist.get_world_size(), rank=dist.get_rank()),
|
50 |
+
shuffle=False, drop_last=False,
|
51 |
+
)
|
52 |
+
del dataset_val
|
53 |
+
|
54 |
+
ld_train = DataLoader(
|
55 |
+
dataset=dataset_train, num_workers=args.workers, pin_memory=True,
|
56 |
+
generator=args.get_different_generator_for_each_rank(), # worker_init_fn=worker_init_fn,
|
57 |
+
batch_sampler=DistInfiniteBatchSampler(
|
58 |
+
dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, same_seed_for_all_ranks=args.same_seed_for_all_ranks,
|
59 |
+
shuffle=True, fill_last=True, rank=dist.get_rank(), world_size=dist.get_world_size(), start_ep=start_ep, start_it=start_it,
|
60 |
+
),
|
61 |
+
)
|
62 |
+
del dataset_train
|
63 |
+
|
64 |
+
[print(line) for line in auto_resume_info]
|
65 |
+
print(f'[dataloader multi processing] ...', end='', flush=True)
|
66 |
+
stt = time.time()
|
67 |
+
iters_train = len(ld_train)
|
68 |
+
ld_train = iter(ld_train)
|
69 |
+
# noinspection PyArgumentList
|
70 |
+
print(f' [dataloader multi processing](*) finished! ({time.time()-stt:.2f}s)', flush=True, clean=True)
|
71 |
+
print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}, types(tr, va)={types}')
|
72 |
+
|
73 |
+
else:
|
74 |
+
num_classes = 1000
|
75 |
+
ld_val = ld_train = None
|
76 |
+
iters_train = 10
|
77 |
+
|
78 |
+
# build models
|
79 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
80 |
+
from models import VAR, VQVAE, build_vae_var
|
81 |
+
from trainer import VARTrainer
|
82 |
+
from utils.amp_sc import AmpOptimizer
|
83 |
+
from utils.lr_control import filter_params
|
84 |
+
|
85 |
+
vae_local, var_wo_ddp = build_vae_var(
|
86 |
+
V=4096, Cvae=32, ch=160, share_quant_resi=4, # hard-coded VQVAE hyperparameters
|
87 |
+
device=dist.get_device(), patch_nums=args.patch_nums,
|
88 |
+
num_classes=num_classes, depth=args.depth, shared_aln=args.saln, attn_l2_norm=args.anorm,
|
89 |
+
flash_if_available=args.fuse, fused_if_available=args.fuse,
|
90 |
+
init_adaln=args.aln, init_adaln_gamma=args.alng, init_head=args.hd, init_std=args.ini,
|
91 |
+
)
|
92 |
+
|
93 |
+
vae_ckpt = 'vae_ch160v4096z32.pth'
|
94 |
+
if dist.is_local_master():
|
95 |
+
if not os.path.exists(vae_ckpt):
|
96 |
+
os.system(f'wget https://huggingface.co/FoundationVision/var/resolve/main/{vae_ckpt}')
|
97 |
+
dist.barrier()
|
98 |
+
vae_local.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
|
99 |
+
|
100 |
+
vae_local: VQVAE = args.compile_model(vae_local, args.vfast)
|
101 |
+
var_wo_ddp: VAR = args.compile_model(var_wo_ddp, args.tfast)
|
102 |
+
var: DDP = (DDP if dist.initialized() else NullDDP)(var_wo_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
|
103 |
+
|
104 |
+
print(f'[INIT] VAR model = {var_wo_ddp}\n\n')
|
105 |
+
count_p = lambda m: f'{sum(p.numel() for p in m.parameters())/1e6:.2f}'
|
106 |
+
print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAE', vae_local), ('VAE.enc', vae_local.encoder), ('VAE.dec', vae_local.decoder), ('VAE.quant', vae_local.quantize))]))
|
107 |
+
print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAR', var_wo_ddp),)]) + '\n\n')
|
108 |
+
|
109 |
+
# build optimizer
|
110 |
+
names, paras, para_groups = filter_params(var_wo_ddp, nowd_keys={
|
111 |
+
'cls_token', 'start_token', 'task_token', 'cfg_uncond',
|
112 |
+
'pos_embed', 'pos_1LC', 'pos_start', 'start_pos', 'lvl_embed',
|
113 |
+
'gamma', 'beta',
|
114 |
+
'ada_gss', 'moe_bias',
|
115 |
+
'scale_mul',
|
116 |
+
})
|
117 |
+
opt_clz = {
|
118 |
+
'adam': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),
|
119 |
+
'adamw': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),
|
120 |
+
}[args.opt.lower().strip()]
|
121 |
+
opt_kw = dict(lr=args.tlr, weight_decay=0)
|
122 |
+
print(f'[INIT] optim={opt_clz}, opt_kw={opt_kw}\n')
|
123 |
+
|
124 |
+
var_optim = AmpOptimizer(
|
125 |
+
mixed_precision=args.fp16, optimizer=opt_clz(params=para_groups, **opt_kw), names=names, paras=paras,
|
126 |
+
grad_clip=args.tclip, n_gradient_accumulation=args.ac
|
127 |
+
)
|
128 |
+
del names, paras, para_groups
|
129 |
+
|
130 |
+
# build trainer
|
131 |
+
trainer = VARTrainer(
|
132 |
+
device=args.device, patch_nums=args.patch_nums, resos=args.resos,
|
133 |
+
vae_local=vae_local, var_wo_ddp=var_wo_ddp, var=var,
|
134 |
+
var_opt=var_optim, label_smooth=args.ls,
|
135 |
+
)
|
136 |
+
if trainer_state is not None and len(trainer_state):
|
137 |
+
trainer.load_state_dict(trainer_state, strict=False, skip_vae=True) # don't load vae again
|
138 |
+
del vae_local, var_wo_ddp, var, var_optim
|
139 |
+
|
140 |
+
if args.local_debug:
|
141 |
+
rng = torch.Generator('cpu')
|
142 |
+
rng.manual_seed(0)
|
143 |
+
B = 4
|
144 |
+
inp = torch.rand(B, 3, args.data_load_reso, args.data_load_reso)
|
145 |
+
label = torch.ones(B, dtype=torch.long)
|
146 |
+
|
147 |
+
me = misc.MetricLogger(delimiter=' ')
|
148 |
+
trainer.train_step(
|
149 |
+
it=0, g_it=0, stepping=True, metric_lg=me, tb_lg=tb_lg,
|
150 |
+
inp_B3HW=inp, label_B=label, prog_si=args.pg0, prog_wp_it=20,
|
151 |
+
)
|
152 |
+
trainer.load_state_dict(trainer.state_dict())
|
153 |
+
trainer.train_step(
|
154 |
+
it=99, g_it=599, stepping=True, metric_lg=me, tb_lg=tb_lg,
|
155 |
+
inp_B3HW=inp, label_B=label, prog_si=-1, prog_wp_it=20,
|
156 |
+
)
|
157 |
+
print({k: meter.global_avg for k, meter in me.meters.items()})
|
158 |
+
|
159 |
+
args.dump_log(); tb_lg.flush(); tb_lg.close()
|
160 |
+
if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
|
161 |
+
sys.stdout.close(), sys.stderr.close()
|
162 |
+
exit(0)
|
163 |
+
|
164 |
+
dist.barrier()
|
165 |
+
return (
|
166 |
+
tb_lg, trainer, start_ep, start_it,
|
167 |
+
iters_train, ld_train, ld_val
|
168 |
+
)
|
169 |
+
|
170 |
+
|
171 |
+
def main_training():
|
172 |
+
args: arg_util.Args = arg_util.init_dist_and_get_args()
|
173 |
+
if args.local_debug:
|
174 |
+
torch.autograd.set_detect_anomaly(True)
|
175 |
+
|
176 |
+
(
|
177 |
+
tb_lg, trainer,
|
178 |
+
start_ep, start_it,
|
179 |
+
iters_train, ld_train, ld_val
|
180 |
+
) = build_everything(args)
|
181 |
+
|
182 |
+
# train
|
183 |
+
start_time = time.time()
|
184 |
+
best_L_mean, best_L_tail, best_acc_mean, best_acc_tail = 999., 999., -1., -1.
|
185 |
+
best_val_loss_mean, best_val_loss_tail, best_val_acc_mean, best_val_acc_tail = 999, 999, -1, -1
|
186 |
+
|
187 |
+
L_mean, L_tail = -1, -1
|
188 |
+
for ep in range(start_ep, args.ep):
|
189 |
+
if hasattr(ld_train, 'sampler') and hasattr(ld_train.sampler, 'set_epoch'):
|
190 |
+
ld_train.sampler.set_epoch(ep)
|
191 |
+
if ep < 3:
|
192 |
+
# noinspection PyArgumentList
|
193 |
+
print(f'[{type(ld_train).__name__}] [ld_train.sampler.set_epoch({ep})]', flush=True, force=True)
|
194 |
+
tb_lg.set_step(ep * iters_train)
|
195 |
+
|
196 |
+
stats, (sec, remain_time, finish_time) = train_one_ep(
|
197 |
+
ep, ep == start_ep, start_it if ep == start_ep else 0, args, tb_lg, ld_train, iters_train, trainer
|
198 |
+
)
|
199 |
+
|
200 |
+
L_mean, L_tail, acc_mean, acc_tail, grad_norm = stats['Lm'], stats['Lt'], stats['Accm'], stats['Acct'], stats['tnm']
|
201 |
+
best_L_mean, best_acc_mean = min(best_L_mean, L_mean), max(best_acc_mean, acc_mean)
|
202 |
+
if L_tail != -1: best_L_tail, best_acc_tail = min(best_L_tail, L_tail), max(best_acc_tail, acc_tail)
|
203 |
+
args.L_mean, args.L_tail, args.acc_mean, args.acc_tail, args.grad_norm = L_mean, L_tail, acc_mean, acc_tail, grad_norm
|
204 |
+
args.cur_ep = f'{ep+1}/{args.ep}'
|
205 |
+
args.remain_time, args.finish_time = remain_time, finish_time
|
206 |
+
|
207 |
+
AR_ep_loss = dict(L_mean=L_mean, L_tail=L_tail, acc_mean=acc_mean, acc_tail=acc_tail)
|
208 |
+
is_val_and_also_saving = (ep + 1) % 10 == 0 or (ep + 1) == args.ep
|
209 |
+
if is_val_and_also_saving:
|
210 |
+
val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail, tot, cost = trainer.eval_ep(ld_val)
|
211 |
+
best_updated = best_val_loss_tail > val_loss_tail
|
212 |
+
best_val_loss_mean, best_val_loss_tail = min(best_val_loss_mean, val_loss_mean), min(best_val_loss_tail, val_loss_tail)
|
213 |
+
best_val_acc_mean, best_val_acc_tail = max(best_val_acc_mean, val_acc_mean), max(best_val_acc_tail, val_acc_tail)
|
214 |
+
AR_ep_loss.update(vL_mean=val_loss_mean, vL_tail=val_loss_tail, vacc_mean=val_acc_mean, vacc_tail=val_acc_tail)
|
215 |
+
args.vL_mean, args.vL_tail, args.vacc_mean, args.vacc_tail = val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail
|
216 |
+
print(f' [*] [ep{ep}] (val {tot}) Lm: {L_mean:.4f}, Lt: {L_tail:.4f}, Acc m&t: {acc_mean:.2f} {acc_tail:.2f}, Val cost: {cost:.2f}s')
|
217 |
+
|
218 |
+
if dist.is_local_master():
|
219 |
+
local_out_ckpt = os.path.join(args.local_out_dir_path, 'ar-ckpt-last.pth')
|
220 |
+
local_out_ckpt_best = os.path.join(args.local_out_dir_path, 'ar-ckpt-best.pth')
|
221 |
+
print(f'[saving ckpt] ...', end='', flush=True)
|
222 |
+
torch.save({
|
223 |
+
'epoch': ep+1,
|
224 |
+
'iter': 0,
|
225 |
+
'trainer': trainer.state_dict(),
|
226 |
+
'args': args.state_dict(),
|
227 |
+
}, local_out_ckpt)
|
228 |
+
if best_updated:
|
229 |
+
shutil.copy(local_out_ckpt, local_out_ckpt_best)
|
230 |
+
print(f' [saving ckpt](*) finished! @ {local_out_ckpt}', flush=True, clean=True)
|
231 |
+
dist.barrier()
|
232 |
+
|
233 |
+
print( f' [ep{ep}] (training ) Lm: {best_L_mean:.3f} ({L_mean:.3f}), Lt: {best_L_tail:.3f} ({L_tail:.3f}), Acc m&t: {best_acc_mean:.2f} {best_acc_tail:.2f}, Remain: {remain_time}, Finish: {finish_time}', flush=True)
|
234 |
+
tb_lg.update(head='AR_ep_loss', step=ep+1, **AR_ep_loss)
|
235 |
+
tb_lg.update(head='AR_z_burnout', step=ep+1, rest_hours=round(sec / 60 / 60, 2))
|
236 |
+
args.dump_log(); tb_lg.flush()
|
237 |
+
|
238 |
+
total_time = f'{(time.time() - start_time) / 60 / 60:.1f}h'
|
239 |
+
print('\n\n')
|
240 |
+
print(f' [*] [PT finished] Total cost: {total_time}, Lm: {best_L_mean:.3f} ({L_mean}), Lt: {best_L_tail:.3f} ({L_tail})')
|
241 |
+
print('\n\n')
|
242 |
+
|
243 |
+
del stats
|
244 |
+
del iters_train, ld_train
|
245 |
+
time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3)
|
246 |
+
|
247 |
+
args.remain_time, args.finish_time = '-', time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() - 60))
|
248 |
+
print(f'final args:\n\n{str(args)}')
|
249 |
+
args.dump_log(); tb_lg.flush(); tb_lg.close()
|
250 |
+
dist.barrier()
|
251 |
+
|
252 |
+
|
253 |
+
def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_util.Args, tb_lg: misc.TensorboardLogger, ld_or_itrt, iters_train: int, trainer):
|
254 |
+
# import heavy packages after Dataloader object creation
|
255 |
+
from trainer import VARTrainer
|
256 |
+
from utils.lr_control import lr_wd_annealing
|
257 |
+
trainer: VARTrainer
|
258 |
+
|
259 |
+
step_cnt = 0
|
260 |
+
me = misc.MetricLogger(delimiter=' ')
|
261 |
+
me.add_meter('tlr', misc.SmoothedValue(window_size=1, fmt='{value:.2g}'))
|
262 |
+
me.add_meter('tnm', misc.SmoothedValue(window_size=1, fmt='{value:.2f}'))
|
263 |
+
[me.add_meter(x, misc.SmoothedValue(fmt='{median:.3f} ({global_avg:.3f})')) for x in ['Lm', 'Lt']]
|
264 |
+
[me.add_meter(x, misc.SmoothedValue(fmt='{median:.2f} ({global_avg:.2f})')) for x in ['Accm', 'Acct']]
|
265 |
+
header = f'[Ep]: [{ep:4d}/{args.ep}]'
|
266 |
+
|
267 |
+
if is_first_ep:
|
268 |
+
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
269 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
270 |
+
g_it, max_it = ep * iters_train, args.ep * iters_train
|
271 |
+
|
272 |
+
for it, (inp, label) in me.log_every(start_it, iters_train, ld_or_itrt, 30 if iters_train > 8000 else 5, header):
|
273 |
+
g_it = ep * iters_train + it
|
274 |
+
if it < start_it: continue
|
275 |
+
if is_first_ep and it == start_it: warnings.resetwarnings()
|
276 |
+
|
277 |
+
inp = inp.to(args.device, non_blocking=True)
|
278 |
+
label = label.to(args.device, non_blocking=True)
|
279 |
+
|
280 |
+
args.cur_it = f'{it+1}/{iters_train}'
|
281 |
+
|
282 |
+
wp_it = args.wp * iters_train
|
283 |
+
min_tlr, max_tlr, min_twd, max_twd = lr_wd_annealing(args.sche, trainer.var_opt.optimizer, args.tlr, args.twd, args.twde, g_it, wp_it, max_it, wp0=args.wp0, wpe=args.wpe)
|
284 |
+
args.cur_lr, args.cur_wd = max_tlr, max_twd
|
285 |
+
|
286 |
+
if args.pg: # default: args.pg == 0.0, means no progressive training, won't get into this
|
287 |
+
if g_it <= wp_it: prog_si = args.pg0
|
288 |
+
elif g_it >= max_it*args.pg: prog_si = len(args.patch_nums) - 1
|
289 |
+
else:
|
290 |
+
delta = len(args.patch_nums) - 1 - args.pg0
|
291 |
+
progress = min(max((g_it - wp_it) / (max_it*args.pg - wp_it), 0), 1) # from 0 to 1
|
292 |
+
prog_si = args.pg0 + round(progress * delta) # from args.pg0 to len(args.patch_nums)-1
|
293 |
+
else:
|
294 |
+
prog_si = -1
|
295 |
+
|
296 |
+
stepping = (g_it + 1) % args.ac == 0
|
297 |
+
step_cnt += int(stepping)
|
298 |
+
|
299 |
+
grad_norm, scale_log2 = trainer.train_step(
|
300 |
+
it=it, g_it=g_it, stepping=stepping, metric_lg=me, tb_lg=tb_lg,
|
301 |
+
inp_B3HW=inp, label_B=label, prog_si=prog_si, prog_wp_it=args.pgwp * iters_train,
|
302 |
+
)
|
303 |
+
|
304 |
+
me.update(tlr=max_tlr)
|
305 |
+
tb_lg.set_step(step=g_it)
|
306 |
+
tb_lg.update(head='AR_opt_lr/lr_min', sche_tlr=min_tlr)
|
307 |
+
tb_lg.update(head='AR_opt_lr/lr_max', sche_tlr=max_tlr)
|
308 |
+
tb_lg.update(head='AR_opt_wd/wd_max', sche_twd=max_twd)
|
309 |
+
tb_lg.update(head='AR_opt_wd/wd_min', sche_twd=min_twd)
|
310 |
+
tb_lg.update(head='AR_opt_grad/fp16', scale_log2=scale_log2)
|
311 |
+
|
312 |
+
if args.tclip > 0:
|
313 |
+
tb_lg.update(head='AR_opt_grad/grad', grad_norm=grad_norm)
|
314 |
+
tb_lg.update(head='AR_opt_grad/grad', grad_clip=args.tclip)
|
315 |
+
|
316 |
+
me.synchronize_between_processes()
|
317 |
+
return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds(max_it - (g_it + 1) + (args.ep - ep) * 15) # +15: other cost
|
318 |
+
|
319 |
+
|
320 |
+
class NullDDP(torch.nn.Module):
|
321 |
+
def __init__(self, module, *args, **kwargs):
|
322 |
+
super(NullDDP, self).__init__()
|
323 |
+
self.module = module
|
324 |
+
self.require_backward_grad_sync = False
|
325 |
+
|
326 |
+
def forward(self, *args, **kwargs):
|
327 |
+
return self.module(*args, **kwargs)
|
328 |
+
|
329 |
+
|
330 |
+
if __name__ == '__main__':
|
331 |
+
try: main_training()
|
332 |
+
finally:
|
333 |
+
dist.finalize()
|
334 |
+
if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
|
335 |
+
sys.stdout.close(), sys.stderr.close()
|
VAR/trainer.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from typing import List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
|
9 |
+
import dist
|
10 |
+
from models import VAR, VQVAE, VectorQuantizer2
|
11 |
+
from utils.amp_sc import AmpOptimizer
|
12 |
+
from utils.misc import MetricLogger, TensorboardLogger
|
13 |
+
|
14 |
+
Ten = torch.Tensor
|
15 |
+
FTen = torch.Tensor
|
16 |
+
ITen = torch.LongTensor
|
17 |
+
BTen = torch.BoolTensor
|
18 |
+
|
19 |
+
|
20 |
+
class VARTrainer(object):
|
21 |
+
def __init__(
|
22 |
+
self, device, patch_nums: Tuple[int, ...], resos: Tuple[int, ...],
|
23 |
+
vae_local: VQVAE, var_wo_ddp: VAR, var: DDP,
|
24 |
+
var_opt: AmpOptimizer, label_smooth: float,
|
25 |
+
):
|
26 |
+
super(VARTrainer, self).__init__()
|
27 |
+
|
28 |
+
self.var, self.vae_local, self.quantize_local = var, vae_local, vae_local.quantize
|
29 |
+
self.quantize_local: VectorQuantizer2
|
30 |
+
self.var_wo_ddp: VAR = var_wo_ddp # after torch.compile
|
31 |
+
self.var_opt = var_opt
|
32 |
+
|
33 |
+
del self.var_wo_ddp.rng
|
34 |
+
self.var_wo_ddp.rng = torch.Generator(device=device)
|
35 |
+
|
36 |
+
self.label_smooth = label_smooth
|
37 |
+
self.train_loss = nn.CrossEntropyLoss(label_smoothing=label_smooth, reduction='none')
|
38 |
+
self.val_loss = nn.CrossEntropyLoss(label_smoothing=0.0, reduction='mean')
|
39 |
+
self.L = sum(pn * pn for pn in patch_nums)
|
40 |
+
self.last_l = patch_nums[-1] * patch_nums[-1]
|
41 |
+
self.loss_weight = torch.ones(1, self.L, device=device) / self.L
|
42 |
+
|
43 |
+
self.patch_nums, self.resos = patch_nums, resos
|
44 |
+
self.begin_ends = []
|
45 |
+
cur = 0
|
46 |
+
for i, pn in enumerate(patch_nums):
|
47 |
+
self.begin_ends.append((cur, cur + pn * pn))
|
48 |
+
cur += pn*pn
|
49 |
+
|
50 |
+
self.prog_it = 0
|
51 |
+
self.last_prog_si = -1
|
52 |
+
self.first_prog = True
|
53 |
+
|
54 |
+
@torch.no_grad()
|
55 |
+
def eval_ep(self, ld_val: DataLoader):
|
56 |
+
tot = 0
|
57 |
+
L_mean, L_tail, acc_mean, acc_tail = 0, 0, 0, 0
|
58 |
+
stt = time.time()
|
59 |
+
training = self.var_wo_ddp.training
|
60 |
+
self.var_wo_ddp.eval()
|
61 |
+
for inp_B3HW, label_B in ld_val:
|
62 |
+
B, V = label_B.shape[0], self.vae_local.vocab_size
|
63 |
+
inp_B3HW = inp_B3HW.to(dist.get_device(), non_blocking=True)
|
64 |
+
label_B = label_B.to(dist.get_device(), non_blocking=True)
|
65 |
+
|
66 |
+
gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)
|
67 |
+
gt_BL = torch.cat(gt_idx_Bl, dim=1)
|
68 |
+
x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)
|
69 |
+
|
70 |
+
self.var_wo_ddp.forward
|
71 |
+
logits_BLV = self.var_wo_ddp(label_B, x_BLCv_wo_first_l)
|
72 |
+
L_mean += self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)) * B
|
73 |
+
L_tail += self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)) * B
|
74 |
+
acc_mean += (logits_BLV.data.argmax(dim=-1) == gt_BL).sum() * (100/gt_BL.shape[1])
|
75 |
+
acc_tail += (logits_BLV.data[:, -self.last_l:].argmax(dim=-1) == gt_BL[:, -self.last_l:]).sum() * (100 / self.last_l)
|
76 |
+
tot += B
|
77 |
+
self.var_wo_ddp.train(training)
|
78 |
+
|
79 |
+
stats = L_mean.new_tensor([L_mean.item(), L_tail.item(), acc_mean.item(), acc_tail.item(), tot])
|
80 |
+
dist.allreduce(stats)
|
81 |
+
tot = round(stats[-1].item())
|
82 |
+
stats /= tot
|
83 |
+
L_mean, L_tail, acc_mean, acc_tail, _ = stats.tolist()
|
84 |
+
return L_mean, L_tail, acc_mean, acc_tail, tot, time.time()-stt
|
85 |
+
|
86 |
+
def train_step(
|
87 |
+
self, it: int, g_it: int, stepping: bool, metric_lg: MetricLogger, tb_lg: TensorboardLogger,
|
88 |
+
inp_B3HW: FTen, label_B: Union[ITen, FTen], prog_si: int, prog_wp_it: float,
|
89 |
+
) -> Tuple[Optional[Union[Ten, float]], Optional[float]]:
|
90 |
+
# if progressive training
|
91 |
+
self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = prog_si
|
92 |
+
if self.last_prog_si != prog_si:
|
93 |
+
if self.last_prog_si != -1: self.first_prog = False
|
94 |
+
self.last_prog_si = prog_si
|
95 |
+
self.prog_it = 0
|
96 |
+
self.prog_it += 1
|
97 |
+
prog_wp = max(min(self.prog_it / prog_wp_it, 1), 0.01)
|
98 |
+
if self.first_prog: prog_wp = 1 # no prog warmup at first prog stage, as it's already solved in wp
|
99 |
+
if prog_si == len(self.patch_nums) - 1: prog_si = -1 # max prog, as if no prog
|
100 |
+
|
101 |
+
# forward
|
102 |
+
B, V = label_B.shape[0], self.vae_local.vocab_size
|
103 |
+
self.var.require_backward_grad_sync = stepping
|
104 |
+
|
105 |
+
gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)
|
106 |
+
gt_BL = torch.cat(gt_idx_Bl, dim=1)
|
107 |
+
x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)
|
108 |
+
|
109 |
+
with self.var_opt.amp_ctx:
|
110 |
+
self.var_wo_ddp.forward
|
111 |
+
logits_BLV = self.var(label_B, x_BLCv_wo_first_l)
|
112 |
+
loss = self.train_loss(logits_BLV.view(-1, V), gt_BL.view(-1)).view(B, -1)
|
113 |
+
if prog_si >= 0: # in progressive training
|
114 |
+
bg, ed = self.begin_ends[prog_si]
|
115 |
+
assert logits_BLV.shape[1] == gt_BL.shape[1] == ed
|
116 |
+
lw = self.loss_weight[:, :ed].clone()
|
117 |
+
lw[:, bg:ed] *= min(max(prog_wp, 0), 1)
|
118 |
+
else: # not in progressive training
|
119 |
+
lw = self.loss_weight
|
120 |
+
loss = loss.mul(lw).sum(dim=-1).mean()
|
121 |
+
|
122 |
+
# backward
|
123 |
+
grad_norm, scale_log2 = self.var_opt.backward_clip_step(loss=loss, stepping=stepping)
|
124 |
+
|
125 |
+
# log
|
126 |
+
pred_BL = logits_BLV.data.argmax(dim=-1)
|
127 |
+
if it == 0 or it in metric_lg.log_iters:
|
128 |
+
Lmean = self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)).item()
|
129 |
+
acc_mean = (pred_BL == gt_BL).float().mean().item() * 100
|
130 |
+
if prog_si >= 0: # in progressive training
|
131 |
+
Ltail = acc_tail = -1
|
132 |
+
else: # not in progressive training
|
133 |
+
Ltail = self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)).item()
|
134 |
+
acc_tail = (pred_BL[:, -self.last_l:] == gt_BL[:, -self.last_l:]).float().mean().item() * 100
|
135 |
+
grad_norm = grad_norm.item()
|
136 |
+
metric_lg.update(Lm=Lmean, Lt=Ltail, Accm=acc_mean, Acct=acc_tail, tnm=grad_norm)
|
137 |
+
|
138 |
+
# log to tensorboard
|
139 |
+
if g_it == 0 or (g_it + 1) % 500 == 0:
|
140 |
+
prob_per_class_is_chosen = pred_BL.view(-1).bincount(minlength=V).float()
|
141 |
+
dist.allreduce(prob_per_class_is_chosen)
|
142 |
+
prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
|
143 |
+
cluster_usage = (prob_per_class_is_chosen > 0.001 / V).float().mean().item() * 100
|
144 |
+
if dist.is_master():
|
145 |
+
if g_it == 0:
|
146 |
+
tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-10000)
|
147 |
+
tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-1000)
|
148 |
+
kw = dict(z_voc_usage=cluster_usage)
|
149 |
+
for si, (bg, ed) in enumerate(self.begin_ends):
|
150 |
+
if 0 <= prog_si < si: break
|
151 |
+
pred, tar = logits_BLV.data[:, bg:ed].reshape(-1, V), gt_BL[:, bg:ed].reshape(-1)
|
152 |
+
acc = (pred.argmax(dim=-1) == tar).float().mean().item() * 100
|
153 |
+
ce = self.val_loss(pred, tar).item()
|
154 |
+
kw[f'acc_{self.resos[si]}'] = acc
|
155 |
+
kw[f'L_{self.resos[si]}'] = ce
|
156 |
+
tb_lg.update(head='AR_iter_loss', **kw, step=g_it)
|
157 |
+
tb_lg.update(head='AR_iter_schedule', prog_a_reso=self.resos[prog_si], prog_si=prog_si, prog_wp=prog_wp, step=g_it)
|
158 |
+
|
159 |
+
self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = -1
|
160 |
+
return grad_norm, scale_log2
|
161 |
+
|
162 |
+
def get_config(self):
|
163 |
+
return {
|
164 |
+
'patch_nums': self.patch_nums, 'resos': self.resos,
|
165 |
+
'label_smooth': self.label_smooth,
|
166 |
+
'prog_it': self.prog_it, 'last_prog_si': self.last_prog_si, 'first_prog': self.first_prog,
|
167 |
+
}
|
168 |
+
|
169 |
+
def state_dict(self):
|
170 |
+
state = {'config': self.get_config()}
|
171 |
+
for k in ('var_wo_ddp', 'vae_local', 'var_opt'):
|
172 |
+
m = getattr(self, k)
|
173 |
+
if m is not None:
|
174 |
+
if hasattr(m, '_orig_mod'):
|
175 |
+
m = m._orig_mod
|
176 |
+
state[k] = m.state_dict()
|
177 |
+
return state
|
178 |
+
|
179 |
+
def load_state_dict(self, state, strict=True, skip_vae=False):
|
180 |
+
for k in ('var_wo_ddp', 'vae_local', 'var_opt'):
|
181 |
+
if skip_vae and 'vae' in k: continue
|
182 |
+
m = getattr(self, k)
|
183 |
+
if m is not None:
|
184 |
+
if hasattr(m, '_orig_mod'):
|
185 |
+
m = m._orig_mod
|
186 |
+
ret = m.load_state_dict(state[k], strict=strict)
|
187 |
+
if ret is not None:
|
188 |
+
missing, unexpected = ret
|
189 |
+
print(f'[VARTrainer.load_state_dict] {k} missing: {missing}')
|
190 |
+
print(f'[VARTrainer.load_state_dict] {k} unexpected: {unexpected}')
|
191 |
+
|
192 |
+
config: dict = state.pop('config', None)
|
193 |
+
self.prog_it = config.get('prog_it', 0)
|
194 |
+
self.last_prog_si = config.get('last_prog_si', -1)
|
195 |
+
self.first_prog = config.get('first_prog', True)
|
196 |
+
if config is not None:
|
197 |
+
for k, v in self.get_config().items():
|
198 |
+
if config.get(k, None) != v:
|
199 |
+
err = f'[VAR.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={config.get(k, None)})'
|
200 |
+
if strict: raise AttributeError(err)
|
201 |
+
else: print(err)
|
VAR/utils/amp_sc.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
class NullCtx:
|
8 |
+
def __enter__(self):
|
9 |
+
pass
|
10 |
+
|
11 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
12 |
+
pass
|
13 |
+
|
14 |
+
|
15 |
+
class AmpOptimizer:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
mixed_precision: int,
|
19 |
+
optimizer: torch.optim.Optimizer, names: List[str], paras: List[torch.nn.Parameter],
|
20 |
+
grad_clip: float, n_gradient_accumulation: int = 1,
|
21 |
+
):
|
22 |
+
self.enable_amp = mixed_precision > 0
|
23 |
+
self.using_fp16_rather_bf16 = mixed_precision == 1
|
24 |
+
|
25 |
+
if self.enable_amp:
|
26 |
+
self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=True)
|
27 |
+
self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000) if self.using_fp16_rather_bf16 else None # only fp16 needs a scaler
|
28 |
+
else:
|
29 |
+
self.amp_ctx = NullCtx()
|
30 |
+
self.scaler = None
|
31 |
+
|
32 |
+
self.optimizer, self.names, self.paras = optimizer, names, paras # paras have been filtered so everyone requires grad
|
33 |
+
self.grad_clip = grad_clip
|
34 |
+
self.early_clipping = self.grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
|
35 |
+
self.late_clipping = self.grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')
|
36 |
+
|
37 |
+
self.r_accu = 1 / n_gradient_accumulation # r_accu == 1.0 / n_gradient_accumulation
|
38 |
+
|
39 |
+
def backward_clip_step(
|
40 |
+
self, stepping: bool, loss: torch.Tensor,
|
41 |
+
) -> Tuple[Optional[Union[torch.Tensor, float]], Optional[float]]:
|
42 |
+
# backward
|
43 |
+
loss = loss.mul(self.r_accu) # r_accu == 1.0 / n_gradient_accumulation
|
44 |
+
orig_norm = scaler_sc = None
|
45 |
+
if self.scaler is not None:
|
46 |
+
self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
|
47 |
+
else:
|
48 |
+
loss.backward(retain_graph=False, create_graph=False)
|
49 |
+
|
50 |
+
if stepping:
|
51 |
+
if self.scaler is not None: self.scaler.unscale_(self.optimizer)
|
52 |
+
if self.early_clipping:
|
53 |
+
orig_norm = torch.nn.utils.clip_grad_norm_(self.paras, self.grad_clip)
|
54 |
+
|
55 |
+
if self.scaler is not None:
|
56 |
+
self.scaler.step(self.optimizer)
|
57 |
+
scaler_sc: float = self.scaler.get_scale()
|
58 |
+
if scaler_sc > 32768.: # fp16 will overflow when >65536, so multiply 32768 could be dangerous
|
59 |
+
self.scaler.update(new_scale=32768.)
|
60 |
+
else:
|
61 |
+
self.scaler.update()
|
62 |
+
try:
|
63 |
+
scaler_sc = float(math.log2(scaler_sc))
|
64 |
+
except Exception as e:
|
65 |
+
print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
|
66 |
+
raise e
|
67 |
+
else:
|
68 |
+
self.optimizer.step()
|
69 |
+
|
70 |
+
if self.late_clipping:
|
71 |
+
orig_norm = self.optimizer.global_grad_norm
|
72 |
+
|
73 |
+
self.optimizer.zero_grad(set_to_none=True)
|
74 |
+
|
75 |
+
return orig_norm, scaler_sc
|
76 |
+
|
77 |
+
def state_dict(self):
|
78 |
+
return {
|
79 |
+
'optimizer': self.optimizer.state_dict()
|
80 |
+
} if self.scaler is None else {
|
81 |
+
'scaler': self.scaler.state_dict(),
|
82 |
+
'optimizer': self.optimizer.state_dict()
|
83 |
+
}
|
84 |
+
|
85 |
+
def load_state_dict(self, state, strict=True):
|
86 |
+
if self.scaler is not None:
|
87 |
+
try: self.scaler.load_state_dict(state['scaler'])
|
88 |
+
except Exception as e: print(f'[fp16 load_state_dict err] {e}')
|
89 |
+
self.optimizer.load_state_dict(state['optimizer'])
|
VAR/utils/arg_util.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import re
|
5 |
+
import subprocess
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
from collections import OrderedDict
|
9 |
+
from typing import Optional, Union
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
try:
|
15 |
+
from tap import Tap
|
16 |
+
except ImportError as e:
|
17 |
+
print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
|
18 |
+
print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
|
19 |
+
time.sleep(5)
|
20 |
+
raise e
|
21 |
+
|
22 |
+
import dist
|
23 |
+
|
24 |
+
|
25 |
+
class Args(Tap):
|
26 |
+
data_path: str = '/path/to/imagenet'
|
27 |
+
exp_name: str = 'text'
|
28 |
+
|
29 |
+
# VAE
|
30 |
+
vfast: int = 0 # torch.compile VAE; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
|
31 |
+
# VAR
|
32 |
+
tfast: int = 0 # torch.compile VAR; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
|
33 |
+
depth: int = 16 # VAR depth
|
34 |
+
# VAR initialization
|
35 |
+
ini: float = -1 # -1: automated model parameter initialization
|
36 |
+
hd: float = 0.02 # head.w *= hd
|
37 |
+
aln: float = 0.5 # the multiplier of ada_lin.w's initialization
|
38 |
+
alng: float = 1e-5 # the multiplier of ada_lin.w[gamma channels]'s initialization
|
39 |
+
# VAR optimization
|
40 |
+
fp16: int = 0 # 1: using fp16, 2: bf16
|
41 |
+
tblr: float = 1e-4 # base lr
|
42 |
+
tlr: float = None # lr = base lr * (bs / 256)
|
43 |
+
twd: float = 0.05 # initial wd
|
44 |
+
twde: float = 0 # final wd, =twde or twd
|
45 |
+
tclip: float = 2. # <=0 for not using grad clip
|
46 |
+
ls: float = 0.0 # label smooth
|
47 |
+
|
48 |
+
bs: int = 768 # global batch size
|
49 |
+
batch_size: int = 0 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size() / 8) * 8
|
50 |
+
glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
|
51 |
+
ac: int = 1 # gradient accumulation
|
52 |
+
|
53 |
+
ep: int = 250
|
54 |
+
wp: float = 0
|
55 |
+
wp0: float = 0.005 # initial lr ratio at the begging of lr warm up
|
56 |
+
wpe: float = 0.01 # final lr ratio at the end of training
|
57 |
+
sche: str = 'lin0' # lr schedule
|
58 |
+
|
59 |
+
opt: str = 'adamw' # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work
|
60 |
+
afuse: bool = True # fused adamw
|
61 |
+
|
62 |
+
# other hps
|
63 |
+
saln: bool = False # whether to use shared adaln
|
64 |
+
anorm: bool = True # whether to use L2 normalized attention
|
65 |
+
fuse: bool = True # whether to use fused op like flash attn, xformers, fused MLP, fused LayerNorm, etc.
|
66 |
+
|
67 |
+
# data
|
68 |
+
pn: str = '1_2_3_4_5_6_8_10_13_16'
|
69 |
+
patch_size: int = 16
|
70 |
+
patch_nums: tuple = None # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
|
71 |
+
resos: tuple = None # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)
|
72 |
+
|
73 |
+
data_load_reso: int = None # [automatically set; don't specify this] would be max(patch_nums) * patch_size
|
74 |
+
mid_reso: float = 1.125 # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso
|
75 |
+
hflip: bool = False # augmentation: horizontal flip
|
76 |
+
workers: int = 0 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
|
77 |
+
|
78 |
+
# progressive training
|
79 |
+
pg: float = 0.0 # >0 for use progressive training during [0%, this] of training
|
80 |
+
pg0: int = 4 # progressive initial stage, 0: from the 1st token map, 1: from the 2nd token map, etc
|
81 |
+
pgwp: float = 0 # num of warmup epochs at each progressive stage
|
82 |
+
|
83 |
+
# would be automatically set in runtime
|
84 |
+
cmd: str = ' '.join(sys.argv[1:]) # [automatically set; don't specify this]
|
85 |
+
branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
|
86 |
+
commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
|
87 |
+
commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
|
88 |
+
acc_mean: float = None # [automatically set; don't specify this]
|
89 |
+
acc_tail: float = None # [automatically set; don't specify this]
|
90 |
+
L_mean: float = None # [automatically set; don't specify this]
|
91 |
+
L_tail: float = None # [automatically set; don't specify this]
|
92 |
+
vacc_mean: float = None # [automatically set; don't specify this]
|
93 |
+
vacc_tail: float = None # [automatically set; don't specify this]
|
94 |
+
vL_mean: float = None # [automatically set; don't specify this]
|
95 |
+
vL_tail: float = None # [automatically set; don't specify this]
|
96 |
+
grad_norm: float = None # [automatically set; don't specify this]
|
97 |
+
cur_lr: float = None # [automatically set; don't specify this]
|
98 |
+
cur_wd: float = None # [automatically set; don't specify this]
|
99 |
+
cur_it: str = '' # [automatically set; don't specify this]
|
100 |
+
cur_ep: str = '' # [automatically set; don't specify this]
|
101 |
+
remain_time: str = '' # [automatically set; don't specify this]
|
102 |
+
finish_time: str = '' # [automatically set; don't specify this]
|
103 |
+
|
104 |
+
# environment
|
105 |
+
local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # [automatically set; don't specify this]
|
106 |
+
tb_log_dir_path: str = '...tb-...' # [automatically set; don't specify this]
|
107 |
+
log_txt_path: str = '...' # [automatically set; don't specify this]
|
108 |
+
last_ckpt_path: str = '...' # [automatically set; don't specify this]
|
109 |
+
|
110 |
+
tf32: bool = True # whether to use TensorFloat32
|
111 |
+
device: str = 'cpu' # [automatically set; don't specify this]
|
112 |
+
seed: int = None # seed
|
113 |
+
def seed_everything(self, benchmark: bool):
|
114 |
+
torch.backends.cudnn.enabled = True
|
115 |
+
torch.backends.cudnn.benchmark = benchmark
|
116 |
+
if self.seed is None:
|
117 |
+
torch.backends.cudnn.deterministic = False
|
118 |
+
else:
|
119 |
+
torch.backends.cudnn.deterministic = True
|
120 |
+
seed = self.seed * dist.get_world_size() + dist.get_rank()
|
121 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
122 |
+
random.seed(seed)
|
123 |
+
np.random.seed(seed)
|
124 |
+
torch.manual_seed(seed)
|
125 |
+
if torch.cuda.is_available():
|
126 |
+
torch.cuda.manual_seed(seed)
|
127 |
+
torch.cuda.manual_seed_all(seed)
|
128 |
+
same_seed_for_all_ranks: int = 0 # this is only for distributed sampler
|
129 |
+
def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation
|
130 |
+
if self.seed is None:
|
131 |
+
return None
|
132 |
+
g = torch.Generator()
|
133 |
+
g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
|
134 |
+
return g
|
135 |
+
|
136 |
+
local_debug: bool = 'KEVIN_LOCAL' in os.environ
|
137 |
+
dbg_nan: bool = False # 'KEVIN_LOCAL' in os.environ
|
138 |
+
|
139 |
+
def compile_model(self, m, fast):
|
140 |
+
if fast == 0 or self.local_debug:
|
141 |
+
return m
|
142 |
+
return torch.compile(m, mode={
|
143 |
+
1: 'reduce-overhead',
|
144 |
+
2: 'max-autotune',
|
145 |
+
3: 'default',
|
146 |
+
}[fast]) if hasattr(torch, 'compile') else m
|
147 |
+
|
148 |
+
def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
|
149 |
+
d = (OrderedDict if key_ordered else dict)()
|
150 |
+
# self.as_dict() would contain methods, but we only need variables
|
151 |
+
for k in self.class_variables.keys():
|
152 |
+
if k not in {'device'}: # these are not serializable
|
153 |
+
d[k] = getattr(self, k)
|
154 |
+
return d
|
155 |
+
|
156 |
+
def load_state_dict(self, d: Union[OrderedDict, dict, str]):
|
157 |
+
if isinstance(d, str): # for compatibility with old version
|
158 |
+
d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
|
159 |
+
for k in d.keys():
|
160 |
+
try:
|
161 |
+
setattr(self, k, d[k])
|
162 |
+
except Exception as e:
|
163 |
+
print(f'k={k}, v={d[k]}')
|
164 |
+
raise e
|
165 |
+
|
166 |
+
@staticmethod
|
167 |
+
def set_tf32(tf32: bool):
|
168 |
+
if torch.cuda.is_available():
|
169 |
+
torch.backends.cudnn.allow_tf32 = bool(tf32)
|
170 |
+
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
|
171 |
+
if hasattr(torch, 'set_float32_matmul_precision'):
|
172 |
+
torch.set_float32_matmul_precision('high' if tf32 else 'highest')
|
173 |
+
print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
|
174 |
+
print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
|
175 |
+
print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
|
176 |
+
|
177 |
+
def dump_log(self):
|
178 |
+
if not dist.is_local_master():
|
179 |
+
return
|
180 |
+
if '1/' in self.cur_ep: # first time to dump log
|
181 |
+
with open(self.log_txt_path, 'w') as fp:
|
182 |
+
json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)
|
183 |
+
fp.write('\n')
|
184 |
+
|
185 |
+
log_dict = {}
|
186 |
+
for k, v in {
|
187 |
+
'it': self.cur_it, 'ep': self.cur_ep,
|
188 |
+
'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,
|
189 |
+
'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,
|
190 |
+
'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,
|
191 |
+
'remain_time': self.remain_time, 'finish_time': self.finish_time,
|
192 |
+
}.items():
|
193 |
+
if hasattr(v, 'item'): v = v.item()
|
194 |
+
log_dict[k] = v
|
195 |
+
with open(self.log_txt_path, 'a') as fp:
|
196 |
+
fp.write(f'{log_dict}\n')
|
197 |
+
|
198 |
+
def __str__(self):
|
199 |
+
s = []
|
200 |
+
for k in self.class_variables.keys():
|
201 |
+
if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
|
202 |
+
s.append(f' {k:20s}: {getattr(self, k)}')
|
203 |
+
s = '\n'.join(s)
|
204 |
+
return f'{{\n{s}\n}}\n'
|
205 |
+
|
206 |
+
|
207 |
+
def init_dist_and_get_args():
|
208 |
+
for i in range(len(sys.argv)):
|
209 |
+
if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
|
210 |
+
del sys.argv[i]
|
211 |
+
break
|
212 |
+
args = Args(explicit_bool=True).parse_args(known_only=True)
|
213 |
+
if args.local_debug:
|
214 |
+
args.pn = '1_2_3'
|
215 |
+
args.seed = 1
|
216 |
+
args.aln = 1e-2
|
217 |
+
args.alng = 1e-5
|
218 |
+
args.saln = False
|
219 |
+
args.afuse = False
|
220 |
+
args.pg = 0.8
|
221 |
+
args.pg0 = 1
|
222 |
+
else:
|
223 |
+
if args.data_path == '/path/to/imagenet':
|
224 |
+
raise ValueError(f'{"*"*40} please specify --data_path=/path/to/imagenet {"*"*40}')
|
225 |
+
|
226 |
+
# warn args.extra_args
|
227 |
+
if len(args.extra_args) > 0:
|
228 |
+
print(f'======================================================================================')
|
229 |
+
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
|
230 |
+
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
|
231 |
+
print(f'======================================================================================\n\n')
|
232 |
+
|
233 |
+
# init torch distributed
|
234 |
+
from utils import misc
|
235 |
+
os.makedirs(args.local_out_dir_path, exist_ok=True)
|
236 |
+
misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)
|
237 |
+
|
238 |
+
# set env
|
239 |
+
args.set_tf32(args.tf32)
|
240 |
+
args.seed_everything(benchmark=args.pg == 0)
|
241 |
+
|
242 |
+
# update args: data loading
|
243 |
+
args.device = dist.get_device()
|
244 |
+
if args.pn == '256':
|
245 |
+
args.pn = '1_2_3_4_5_6_8_10_13_16'
|
246 |
+
elif args.pn == '512':
|
247 |
+
args.pn = '1_2_3_4_6_9_13_18_24_32'
|
248 |
+
elif args.pn == '1024':
|
249 |
+
args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'
|
250 |
+
args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
|
251 |
+
args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
|
252 |
+
args.data_load_reso = max(args.resos)
|
253 |
+
|
254 |
+
# update args: bs and lr
|
255 |
+
bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())
|
256 |
+
args.batch_size = bs_per_gpu
|
257 |
+
args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
|
258 |
+
args.workers = min(max(0, args.workers), args.batch_size)
|
259 |
+
|
260 |
+
args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
|
261 |
+
args.twde = args.twde or args.twd
|
262 |
+
|
263 |
+
if args.wp == 0:
|
264 |
+
args.wp = args.ep * 1/50
|
265 |
+
|
266 |
+
# update args: progressive training
|
267 |
+
if args.pgwp == 0:
|
268 |
+
args.pgwp = args.ep * 1/300
|
269 |
+
if args.pg > 0:
|
270 |
+
args.sche = f'lin{args.pg:g}'
|
271 |
+
|
272 |
+
# update args: paths
|
273 |
+
args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')
|
274 |
+
args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')
|
275 |
+
_reg_valid_name = re.compile(r'[^\w\-+,.]')
|
276 |
+
tb_name = _reg_valid_name.sub(
|
277 |
+
'_',
|
278 |
+
f'tb-VARd{args.depth}'
|
279 |
+
f'__pn{args.pn}'
|
280 |
+
f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}'
|
281 |
+
)
|
282 |
+
args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)
|
283 |
+
|
284 |
+
return args
|
VAR/utils/data.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
|
3 |
+
import PIL.Image as PImage
|
4 |
+
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
|
5 |
+
from torchvision.transforms import InterpolationMode, transforms
|
6 |
+
|
7 |
+
|
8 |
+
def normalize_01_into_pm1(x): # normalize x from [0, 1] to [-1, 1] by (x*2) - 1
|
9 |
+
return x.add(x).add_(-1)
|
10 |
+
|
11 |
+
|
12 |
+
def build_dataset(
|
13 |
+
data_path: str, final_reso: int,
|
14 |
+
hflip=False, mid_reso=1.125,
|
15 |
+
):
|
16 |
+
# build augmentations
|
17 |
+
mid_reso = round(mid_reso * final_reso) # first resize to mid_reso, then crop to final_reso
|
18 |
+
train_aug, val_aug = [
|
19 |
+
transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso
|
20 |
+
transforms.RandomCrop((final_reso, final_reso)),
|
21 |
+
transforms.ToTensor(), normalize_01_into_pm1,
|
22 |
+
], [
|
23 |
+
transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso
|
24 |
+
transforms.CenterCrop((final_reso, final_reso)),
|
25 |
+
transforms.ToTensor(), normalize_01_into_pm1,
|
26 |
+
]
|
27 |
+
if hflip: train_aug.insert(0, transforms.RandomHorizontalFlip())
|
28 |
+
train_aug, val_aug = transforms.Compose(train_aug), transforms.Compose(val_aug)
|
29 |
+
|
30 |
+
# build dataset
|
31 |
+
train_set = DatasetFolder(root=osp.join(data_path, 'train'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=train_aug)
|
32 |
+
val_set = DatasetFolder(root=osp.join(data_path, 'val'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=val_aug)
|
33 |
+
num_classes = 1000
|
34 |
+
print(f'[Dataset] {len(train_set)=}, {len(val_set)=}, {num_classes=}')
|
35 |
+
print_aug(train_aug, '[train]')
|
36 |
+
print_aug(val_aug, '[val]')
|
37 |
+
|
38 |
+
return num_classes, train_set, val_set
|
39 |
+
|
40 |
+
|
41 |
+
def pil_loader(path):
|
42 |
+
with open(path, 'rb') as f:
|
43 |
+
img: PImage.Image = PImage.open(f).convert('RGB')
|
44 |
+
return img
|
45 |
+
|
46 |
+
|
47 |
+
def print_aug(transform, label):
|
48 |
+
print(f'Transform {label} = ')
|
49 |
+
if hasattr(transform, 'transforms'):
|
50 |
+
for t in transform.transforms:
|
51 |
+
print(t)
|
52 |
+
else:
|
53 |
+
print(transform)
|
54 |
+
print('---------------------------\n')
|
VAR/utils/data_sampler.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch.utils.data.sampler import Sampler
|
4 |
+
|
5 |
+
|
6 |
+
class EvalDistributedSampler(Sampler):
|
7 |
+
def __init__(self, dataset, num_replicas, rank):
|
8 |
+
seps = np.linspace(0, len(dataset), num_replicas+1, dtype=int)
|
9 |
+
beg, end = seps[:-1], seps[1:]
|
10 |
+
beg, end = beg[rank], end[rank]
|
11 |
+
self.indices = tuple(range(beg, end))
|
12 |
+
|
13 |
+
def __iter__(self):
|
14 |
+
return iter(self.indices)
|
15 |
+
|
16 |
+
def __len__(self) -> int:
|
17 |
+
return len(self.indices)
|
18 |
+
|
19 |
+
|
20 |
+
class InfiniteBatchSampler(Sampler):
|
21 |
+
def __init__(self, dataset_len, batch_size, seed_for_all_rank=0, fill_last=False, shuffle=True, drop_last=False, start_ep=0, start_it=0):
|
22 |
+
self.dataset_len = dataset_len
|
23 |
+
self.batch_size = batch_size
|
24 |
+
self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size
|
25 |
+
self.max_p = self.iters_per_ep * batch_size
|
26 |
+
self.fill_last = fill_last
|
27 |
+
self.shuffle = shuffle
|
28 |
+
self.epoch = start_ep
|
29 |
+
self.same_seed_for_all_ranks = seed_for_all_rank
|
30 |
+
self.indices = self.gener_indices()
|
31 |
+
self.start_ep, self.start_it = start_ep, start_it
|
32 |
+
|
33 |
+
def gener_indices(self):
|
34 |
+
if self.shuffle:
|
35 |
+
g = torch.Generator()
|
36 |
+
g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
|
37 |
+
indices = torch.randperm(self.dataset_len, generator=g).numpy()
|
38 |
+
else:
|
39 |
+
indices = torch.arange(self.dataset_len).numpy()
|
40 |
+
|
41 |
+
tails = self.batch_size - (self.dataset_len % self.batch_size)
|
42 |
+
if tails != self.batch_size and self.fill_last:
|
43 |
+
tails = indices[:tails]
|
44 |
+
np.random.shuffle(indices)
|
45 |
+
indices = np.concatenate((indices, tails))
|
46 |
+
|
47 |
+
# built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
|
48 |
+
# noinspection PyTypeChecker
|
49 |
+
return tuple(indices.tolist())
|
50 |
+
|
51 |
+
def __iter__(self):
|
52 |
+
self.epoch = self.start_ep
|
53 |
+
while True:
|
54 |
+
self.epoch += 1
|
55 |
+
p = (self.start_it * self.batch_size) if self.epoch == self.start_ep else 0
|
56 |
+
while p < self.max_p:
|
57 |
+
q = p + self.batch_size
|
58 |
+
yield self.indices[p:q]
|
59 |
+
p = q
|
60 |
+
if self.shuffle:
|
61 |
+
self.indices = self.gener_indices()
|
62 |
+
|
63 |
+
def __len__(self):
|
64 |
+
return self.iters_per_ep
|
65 |
+
|
66 |
+
|
67 |
+
class DistInfiniteBatchSampler(InfiniteBatchSampler):
|
68 |
+
def __init__(self, world_size, rank, dataset_len, glb_batch_size, same_seed_for_all_ranks=0, repeated_aug=0, fill_last=False, shuffle=True, start_ep=0, start_it=0):
|
69 |
+
assert glb_batch_size % world_size == 0
|
70 |
+
self.world_size, self.rank = world_size, rank
|
71 |
+
self.dataset_len = dataset_len
|
72 |
+
self.glb_batch_size = glb_batch_size
|
73 |
+
self.batch_size = glb_batch_size // world_size
|
74 |
+
|
75 |
+
self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
|
76 |
+
self.fill_last = fill_last
|
77 |
+
self.shuffle = shuffle
|
78 |
+
self.repeated_aug = repeated_aug
|
79 |
+
self.epoch = start_ep
|
80 |
+
self.same_seed_for_all_ranks = same_seed_for_all_ranks
|
81 |
+
self.indices = self.gener_indices()
|
82 |
+
self.start_ep, self.start_it = start_ep, start_it
|
83 |
+
|
84 |
+
def gener_indices(self):
|
85 |
+
global_max_p = self.iters_per_ep * self.glb_batch_size # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0
|
86 |
+
# print(f'global_max_p = iters_per_ep({self.iters_per_ep}) * glb_batch_size({self.glb_batch_size}) = {global_max_p}')
|
87 |
+
if self.shuffle:
|
88 |
+
g = torch.Generator()
|
89 |
+
g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
|
90 |
+
global_indices = torch.randperm(self.dataset_len, generator=g)
|
91 |
+
if self.repeated_aug > 1:
|
92 |
+
global_indices = global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]
|
93 |
+
else:
|
94 |
+
global_indices = torch.arange(self.dataset_len)
|
95 |
+
filling = global_max_p - global_indices.shape[0]
|
96 |
+
if filling > 0 and self.fill_last:
|
97 |
+
global_indices = torch.cat((global_indices, global_indices[:filling]))
|
98 |
+
# global_indices = tuple(global_indices.numpy().tolist())
|
99 |
+
|
100 |
+
seps = torch.linspace(0, global_indices.shape[0], self.world_size + 1, dtype=torch.int)
|
101 |
+
local_indices = global_indices[seps[self.rank].item():seps[self.rank + 1].item()].tolist()
|
102 |
+
self.max_p = len(local_indices)
|
103 |
+
return local_indices
|
VAR/utils/lr_control.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from pprint import pformat
|
3 |
+
from typing import Tuple, List, Dict, Union
|
4 |
+
|
5 |
+
import torch.nn
|
6 |
+
|
7 |
+
import dist
|
8 |
+
|
9 |
+
|
10 |
+
def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
|
11 |
+
"""Decay the learning rate with half-cycle cosine after warmup"""
|
12 |
+
wp_it = round(wp_it)
|
13 |
+
|
14 |
+
if cur_it < wp_it:
|
15 |
+
cur_lr = wp0 + (1-wp0) * cur_it / wp_it
|
16 |
+
else:
|
17 |
+
pasd = (cur_it - wp_it) / (max_it-1 - wp_it) # [0, 1]
|
18 |
+
rest = 1 - pasd # [1, 0]
|
19 |
+
if sche_type == 'cos':
|
20 |
+
cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))
|
21 |
+
elif sche_type == 'lin':
|
22 |
+
T = 0.15; max_rest = 1-T
|
23 |
+
if pasd < T: cur_lr = 1
|
24 |
+
else: cur_lr = wpe + (1-wpe) * rest / max_rest # 1 to wpe
|
25 |
+
elif sche_type == 'lin0':
|
26 |
+
T = 0.05; max_rest = 1-T
|
27 |
+
if pasd < T: cur_lr = 1
|
28 |
+
else: cur_lr = wpe + (1-wpe) * rest / max_rest
|
29 |
+
elif sche_type == 'lin00':
|
30 |
+
cur_lr = wpe + (1-wpe) * rest
|
31 |
+
elif sche_type.startswith('lin'):
|
32 |
+
T = float(sche_type[3:]); max_rest = 1-T
|
33 |
+
wpe_mid = wpe + (1-wpe) * max_rest
|
34 |
+
wpe_mid = (1 + wpe_mid) / 2
|
35 |
+
if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T
|
36 |
+
else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest
|
37 |
+
elif sche_type == 'exp':
|
38 |
+
T = 0.15; max_rest = 1-T
|
39 |
+
if pasd < T: cur_lr = 1
|
40 |
+
else:
|
41 |
+
expo = (pasd-T) / max_rest * math.log(wpe)
|
42 |
+
cur_lr = math.exp(expo)
|
43 |
+
else:
|
44 |
+
raise NotImplementedError(f'unknown sche_type {sche_type}')
|
45 |
+
|
46 |
+
cur_lr *= peak_lr
|
47 |
+
pasd = cur_it / (max_it-1)
|
48 |
+
cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd))
|
49 |
+
|
50 |
+
inf = 1e6
|
51 |
+
min_lr, max_lr = inf, -1
|
52 |
+
min_wd, max_wd = inf, -1
|
53 |
+
for param_group in optimizer.param_groups:
|
54 |
+
param_group['lr'] = cur_lr * param_group.get('lr_sc', 1) # 'lr_sc' could be assigned
|
55 |
+
max_lr = max(max_lr, param_group['lr'])
|
56 |
+
min_lr = min(min_lr, param_group['lr'])
|
57 |
+
|
58 |
+
param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1)
|
59 |
+
max_wd = max(max_wd, param_group['weight_decay'])
|
60 |
+
if param_group['weight_decay'] > 0:
|
61 |
+
min_wd = min(min_wd, param_group['weight_decay'])
|
62 |
+
|
63 |
+
if min_lr == inf: min_lr = -1
|
64 |
+
if min_wd == inf: min_wd = -1
|
65 |
+
return min_lr, max_lr, min_wd, max_wd
|
66 |
+
|
67 |
+
|
68 |
+
def filter_params(model, nowd_keys=()) -> Tuple[
|
69 |
+
List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
|
70 |
+
]:
|
71 |
+
para_groups, para_groups_dbg = {}, {}
|
72 |
+
names, paras = [], []
|
73 |
+
names_no_grad = []
|
74 |
+
count, numel = 0, 0
|
75 |
+
for name, para in model.named_parameters():
|
76 |
+
name = name.replace('_fsdp_wrapped_module.', '')
|
77 |
+
if not para.requires_grad:
|
78 |
+
names_no_grad.append(name)
|
79 |
+
continue # frozen weights
|
80 |
+
count += 1
|
81 |
+
numel += para.numel()
|
82 |
+
names.append(name)
|
83 |
+
paras.append(para)
|
84 |
+
|
85 |
+
if para.ndim == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
|
86 |
+
cur_wd_sc, group_name = 0., 'ND'
|
87 |
+
else:
|
88 |
+
cur_wd_sc, group_name = 1., 'D'
|
89 |
+
cur_lr_sc = 1.
|
90 |
+
if group_name not in para_groups:
|
91 |
+
para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
|
92 |
+
para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
|
93 |
+
para_groups[group_name]['params'].append(para)
|
94 |
+
para_groups_dbg[group_name]['params'].append(name)
|
95 |
+
|
96 |
+
for g in para_groups_dbg.values():
|
97 |
+
g['params'] = pformat(', '.join(g['params']), width=200)
|
98 |
+
|
99 |
+
print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')
|
100 |
+
|
101 |
+
for rk in range(dist.get_world_size()):
|
102 |
+
dist.barrier()
|
103 |
+
if dist.get_rank() == rk:
|
104 |
+
print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)
|
105 |
+
print('')
|
106 |
+
|
107 |
+
assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
|
108 |
+
return names, paras, list(para_groups.values())
|
VAR/utils/misc.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import functools
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
import subprocess
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
from collections import defaultdict, deque
|
9 |
+
from typing import Iterator, List, Tuple
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import pytz
|
13 |
+
import torch
|
14 |
+
import torch.distributed as tdist
|
15 |
+
|
16 |
+
import dist
|
17 |
+
from utils import arg_util
|
18 |
+
|
19 |
+
os_system = functools.partial(subprocess.call, shell=True)
|
20 |
+
def echo(info):
|
21 |
+
os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')
|
22 |
+
def os_system_get_stdout(cmd):
|
23 |
+
return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
|
24 |
+
def os_system_get_stdout_stderr(cmd):
|
25 |
+
cnt = 0
|
26 |
+
while True:
|
27 |
+
try:
|
28 |
+
sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)
|
29 |
+
except subprocess.TimeoutExpired:
|
30 |
+
cnt += 1
|
31 |
+
print(f'[fetch free_port file] timeout cnt={cnt}')
|
32 |
+
else:
|
33 |
+
return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
|
34 |
+
|
35 |
+
|
36 |
+
def time_str(fmt='[%m-%d %H:%M:%S]'):
|
37 |
+
return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)
|
38 |
+
|
39 |
+
|
40 |
+
def init_distributed_mode(local_out_path, only_sync_master=False, timeout=30):
|
41 |
+
try:
|
42 |
+
dist.initialize(fork=False, timeout=timeout)
|
43 |
+
dist.barrier()
|
44 |
+
except RuntimeError:
|
45 |
+
print(f'{">"*75} NCCL Error {"<"*75}', flush=True)
|
46 |
+
time.sleep(10)
|
47 |
+
|
48 |
+
if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
|
49 |
+
_change_builtin_print(dist.is_local_master())
|
50 |
+
if (dist.is_master() if only_sync_master else dist.is_local_master()) and local_out_path is not None and len(local_out_path):
|
51 |
+
sys.stdout, sys.stderr = SyncPrint(local_out_path, sync_stdout=True), SyncPrint(local_out_path, sync_stdout=False)
|
52 |
+
|
53 |
+
|
54 |
+
def _change_builtin_print(is_master):
|
55 |
+
import builtins as __builtin__
|
56 |
+
|
57 |
+
builtin_print = __builtin__.print
|
58 |
+
if type(builtin_print) != type(open):
|
59 |
+
return
|
60 |
+
|
61 |
+
def prt(*args, **kwargs):
|
62 |
+
force = kwargs.pop('force', False)
|
63 |
+
clean = kwargs.pop('clean', False)
|
64 |
+
deeper = kwargs.pop('deeper', False)
|
65 |
+
if is_master or force:
|
66 |
+
if not clean:
|
67 |
+
f_back = sys._getframe().f_back
|
68 |
+
if deeper and f_back.f_back is not None:
|
69 |
+
f_back = f_back.f_back
|
70 |
+
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
|
71 |
+
builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
|
72 |
+
else:
|
73 |
+
builtin_print(*args, **kwargs)
|
74 |
+
|
75 |
+
__builtin__.print = prt
|
76 |
+
|
77 |
+
|
78 |
+
class SyncPrint(object):
|
79 |
+
def __init__(self, local_output_dir, sync_stdout=True):
|
80 |
+
self.sync_stdout = sync_stdout
|
81 |
+
self.terminal_stream = sys.stdout if sync_stdout else sys.stderr
|
82 |
+
fname = os.path.join(local_output_dir, 'stdout.txt' if sync_stdout else 'stderr.txt')
|
83 |
+
existing = os.path.exists(fname)
|
84 |
+
self.file_stream = open(fname, 'a')
|
85 |
+
if existing:
|
86 |
+
self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str()} ' + '='*55 + '\n')
|
87 |
+
self.file_stream.flush()
|
88 |
+
self.enabled = True
|
89 |
+
|
90 |
+
def write(self, message):
|
91 |
+
self.terminal_stream.write(message)
|
92 |
+
self.file_stream.write(message)
|
93 |
+
|
94 |
+
def flush(self):
|
95 |
+
self.terminal_stream.flush()
|
96 |
+
self.file_stream.flush()
|
97 |
+
|
98 |
+
def close(self):
|
99 |
+
if not self.enabled:
|
100 |
+
return
|
101 |
+
self.enabled = False
|
102 |
+
self.file_stream.flush()
|
103 |
+
self.file_stream.close()
|
104 |
+
if self.sync_stdout:
|
105 |
+
sys.stdout = self.terminal_stream
|
106 |
+
sys.stdout.flush()
|
107 |
+
else:
|
108 |
+
sys.stderr = self.terminal_stream
|
109 |
+
sys.stderr.flush()
|
110 |
+
|
111 |
+
def __del__(self):
|
112 |
+
self.close()
|
113 |
+
|
114 |
+
|
115 |
+
class DistLogger(object):
|
116 |
+
def __init__(self, lg, verbose):
|
117 |
+
self._lg, self._verbose = lg, verbose
|
118 |
+
|
119 |
+
@staticmethod
|
120 |
+
def do_nothing(*args, **kwargs):
|
121 |
+
pass
|
122 |
+
|
123 |
+
def __getattr__(self, attr: str):
|
124 |
+
return getattr(self._lg, attr) if self._verbose else DistLogger.do_nothing
|
125 |
+
|
126 |
+
|
127 |
+
class TensorboardLogger(object):
|
128 |
+
def __init__(self, log_dir, filename_suffix):
|
129 |
+
try: import tensorflow_io as tfio
|
130 |
+
except: pass
|
131 |
+
from torch.utils.tensorboard import SummaryWriter
|
132 |
+
self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)
|
133 |
+
self.step = 0
|
134 |
+
|
135 |
+
def set_step(self, step=None):
|
136 |
+
if step is not None:
|
137 |
+
self.step = step
|
138 |
+
else:
|
139 |
+
self.step += 1
|
140 |
+
|
141 |
+
def update(self, head='scalar', step=None, **kwargs):
|
142 |
+
for k, v in kwargs.items():
|
143 |
+
if v is None:
|
144 |
+
continue
|
145 |
+
# assert isinstance(v, (float, int)), type(v)
|
146 |
+
if step is None: # iter wise
|
147 |
+
it = self.step
|
148 |
+
if it == 0 or (it + 1) % 500 == 0:
|
149 |
+
if hasattr(v, 'item'): v = v.item()
|
150 |
+
self.writer.add_scalar(f'{head}/{k}', v, it)
|
151 |
+
else: # epoch wise
|
152 |
+
if hasattr(v, 'item'): v = v.item()
|
153 |
+
self.writer.add_scalar(f'{head}/{k}', v, step)
|
154 |
+
|
155 |
+
def log_tensor_as_distri(self, tag, tensor1d, step=None):
|
156 |
+
if step is None: # iter wise
|
157 |
+
step = self.step
|
158 |
+
loggable = step == 0 or (step + 1) % 500 == 0
|
159 |
+
else: # epoch wise
|
160 |
+
loggable = True
|
161 |
+
if loggable:
|
162 |
+
try:
|
163 |
+
self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)
|
164 |
+
except Exception as e:
|
165 |
+
print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')
|
166 |
+
|
167 |
+
def log_image(self, tag, img_chw, step=None):
|
168 |
+
if step is None: # iter wise
|
169 |
+
step = self.step
|
170 |
+
loggable = step == 0 or (step + 1) % 500 == 0
|
171 |
+
else: # epoch wise
|
172 |
+
loggable = True
|
173 |
+
if loggable:
|
174 |
+
self.writer.add_image(tag, img_chw, step, dataformats='CHW')
|
175 |
+
|
176 |
+
def flush(self):
|
177 |
+
self.writer.flush()
|
178 |
+
|
179 |
+
def close(self):
|
180 |
+
self.writer.close()
|
181 |
+
|
182 |
+
|
183 |
+
class SmoothedValue(object):
|
184 |
+
"""Track a series of values and provide access to smoothed values over a
|
185 |
+
window or the global series average.
|
186 |
+
"""
|
187 |
+
|
188 |
+
def __init__(self, window_size=30, fmt=None):
|
189 |
+
if fmt is None:
|
190 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
191 |
+
self.deque = deque(maxlen=window_size)
|
192 |
+
self.total = 0.0
|
193 |
+
self.count = 0
|
194 |
+
self.fmt = fmt
|
195 |
+
|
196 |
+
def update(self, value, n=1):
|
197 |
+
self.deque.append(value)
|
198 |
+
self.count += n
|
199 |
+
self.total += value * n
|
200 |
+
|
201 |
+
def synchronize_between_processes(self):
|
202 |
+
"""
|
203 |
+
Warning: does not synchronize the deque!
|
204 |
+
"""
|
205 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
206 |
+
tdist.barrier()
|
207 |
+
tdist.all_reduce(t)
|
208 |
+
t = t.tolist()
|
209 |
+
self.count = int(t[0])
|
210 |
+
self.total = t[1]
|
211 |
+
|
212 |
+
@property
|
213 |
+
def median(self):
|
214 |
+
return np.median(self.deque) if len(self.deque) else 0
|
215 |
+
|
216 |
+
@property
|
217 |
+
def avg(self):
|
218 |
+
return sum(self.deque) / (len(self.deque) or 1)
|
219 |
+
|
220 |
+
@property
|
221 |
+
def global_avg(self):
|
222 |
+
return self.total / (self.count or 1)
|
223 |
+
|
224 |
+
@property
|
225 |
+
def max(self):
|
226 |
+
return max(self.deque)
|
227 |
+
|
228 |
+
@property
|
229 |
+
def value(self):
|
230 |
+
return self.deque[-1] if len(self.deque) else 0
|
231 |
+
|
232 |
+
def time_preds(self, counts) -> Tuple[float, str, str]:
|
233 |
+
remain_secs = counts * self.median
|
234 |
+
return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs))
|
235 |
+
|
236 |
+
def __str__(self):
|
237 |
+
return self.fmt.format(
|
238 |
+
median=self.median,
|
239 |
+
avg=self.avg,
|
240 |
+
global_avg=self.global_avg,
|
241 |
+
max=self.max,
|
242 |
+
value=self.value)
|
243 |
+
|
244 |
+
|
245 |
+
class MetricLogger(object):
|
246 |
+
def __init__(self, delimiter=' '):
|
247 |
+
self.meters = defaultdict(SmoothedValue)
|
248 |
+
self.delimiter = delimiter
|
249 |
+
self.iter_end_t = time.time()
|
250 |
+
self.log_iters = []
|
251 |
+
|
252 |
+
def update(self, **kwargs):
|
253 |
+
for k, v in kwargs.items():
|
254 |
+
if v is None:
|
255 |
+
continue
|
256 |
+
if hasattr(v, 'item'): v = v.item()
|
257 |
+
# assert isinstance(v, (float, int)), type(v)
|
258 |
+
assert isinstance(v, (float, int))
|
259 |
+
self.meters[k].update(v)
|
260 |
+
|
261 |
+
def __getattr__(self, attr):
|
262 |
+
if attr in self.meters:
|
263 |
+
return self.meters[attr]
|
264 |
+
if attr in self.__dict__:
|
265 |
+
return self.__dict__[attr]
|
266 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
267 |
+
type(self).__name__, attr))
|
268 |
+
|
269 |
+
def __str__(self):
|
270 |
+
loss_str = []
|
271 |
+
for name, meter in self.meters.items():
|
272 |
+
if len(meter.deque):
|
273 |
+
loss_str.append(
|
274 |
+
"{}: {}".format(name, str(meter))
|
275 |
+
)
|
276 |
+
return self.delimiter.join(loss_str)
|
277 |
+
|
278 |
+
def synchronize_between_processes(self):
|
279 |
+
for meter in self.meters.values():
|
280 |
+
meter.synchronize_between_processes()
|
281 |
+
|
282 |
+
def add_meter(self, name, meter):
|
283 |
+
self.meters[name] = meter
|
284 |
+
|
285 |
+
def log_every(self, start_it, max_iters, itrt, print_freq, header=None):
|
286 |
+
self.log_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist())
|
287 |
+
self.log_iters.add(start_it)
|
288 |
+
if not header:
|
289 |
+
header = ''
|
290 |
+
start_time = time.time()
|
291 |
+
self.iter_end_t = time.time()
|
292 |
+
self.iter_time = SmoothedValue(fmt='{avg:.4f}')
|
293 |
+
self.data_time = SmoothedValue(fmt='{avg:.4f}')
|
294 |
+
space_fmt = ':' + str(len(str(max_iters))) + 'd'
|
295 |
+
log_msg = [
|
296 |
+
header,
|
297 |
+
'[{0' + space_fmt + '}/{1}]',
|
298 |
+
'eta: {eta}',
|
299 |
+
'{meters}',
|
300 |
+
'time: {time}',
|
301 |
+
'data: {data}'
|
302 |
+
]
|
303 |
+
log_msg = self.delimiter.join(log_msg)
|
304 |
+
|
305 |
+
if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
|
306 |
+
for i in range(start_it, max_iters):
|
307 |
+
obj = next(itrt)
|
308 |
+
self.data_time.update(time.time() - self.iter_end_t)
|
309 |
+
yield i, obj
|
310 |
+
self.iter_time.update(time.time() - self.iter_end_t)
|
311 |
+
if i in self.log_iters:
|
312 |
+
eta_seconds = self.iter_time.global_avg * (max_iters - i)
|
313 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
314 |
+
print(log_msg.format(
|
315 |
+
i, max_iters, eta=eta_string,
|
316 |
+
meters=str(self),
|
317 |
+
time=str(self.iter_time), data=str(self.data_time)), flush=True)
|
318 |
+
self.iter_end_t = time.time()
|
319 |
+
else:
|
320 |
+
if isinstance(itrt, int): itrt = range(itrt)
|
321 |
+
for i, obj in enumerate(itrt):
|
322 |
+
self.data_time.update(time.time() - self.iter_end_t)
|
323 |
+
yield i, obj
|
324 |
+
self.iter_time.update(time.time() - self.iter_end_t)
|
325 |
+
if i in self.log_iters:
|
326 |
+
eta_seconds = self.iter_time.global_avg * (max_iters - i)
|
327 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
328 |
+
print(log_msg.format(
|
329 |
+
i, max_iters, eta=eta_string,
|
330 |
+
meters=str(self),
|
331 |
+
time=str(self.iter_time), data=str(self.data_time)), flush=True)
|
332 |
+
self.iter_end_t = time.time()
|
333 |
+
|
334 |
+
total_time = time.time() - start_time
|
335 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
336 |
+
print('{} Total time: {} ({:.3f} s / it)'.format(
|
337 |
+
header, total_time_str, total_time / max_iters), flush=True)
|
338 |
+
|
339 |
+
|
340 |
+
def glob_with_latest_modified_first(pattern, recursive=False):
|
341 |
+
return sorted(glob.glob(pattern, recursive=recursive), key=os.path.getmtime, reverse=True)
|
342 |
+
|
343 |
+
|
344 |
+
def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, dict, dict]:
|
345 |
+
info = []
|
346 |
+
file = os.path.join(args.local_out_dir_path, pattern)
|
347 |
+
all_ckpt = glob_with_latest_modified_first(file)
|
348 |
+
if len(all_ckpt) == 0:
|
349 |
+
info.append(f'[auto_resume] no ckpt found @ {file}')
|
350 |
+
info.append(f'[auto_resume quit]')
|
351 |
+
return info, 0, 0, {}, {}
|
352 |
+
else:
|
353 |
+
info.append(f'[auto_resume] load ckpt from @ {all_ckpt[0]} ...')
|
354 |
+
ckpt = torch.load(all_ckpt[0], map_location='cpu')
|
355 |
+
ep, it = ckpt['epoch'], ckpt['iter']
|
356 |
+
info.append(f'[auto_resume success] resume from ep{ep}, it{it}')
|
357 |
+
return info, ep, it, ckpt['trainer'], ckpt['args']
|
358 |
+
|
359 |
+
|
360 |
+
def create_npz_from_sample_folder(sample_folder: str):
|
361 |
+
"""
|
362 |
+
Builds a single .npz file from a folder of .png samples. Refer to DiT.
|
363 |
+
"""
|
364 |
+
import os, glob
|
365 |
+
import numpy as np
|
366 |
+
from tqdm import tqdm
|
367 |
+
from PIL import Image
|
368 |
+
|
369 |
+
samples = []
|
370 |
+
pngs = glob.glob(os.path.join(sample_folder, '*.png')) + glob.glob(os.path.join(sample_folder, '*.PNG'))
|
371 |
+
assert len(pngs) == 50_000, f'{len(pngs)} png files found in {sample_folder}, but expected 50,000'
|
372 |
+
for png in tqdm(pngs, desc='Building .npz file from samples (png only)'):
|
373 |
+
with Image.open(png) as sample_pil:
|
374 |
+
sample_np = np.asarray(sample_pil).astype(np.uint8)
|
375 |
+
samples.append(sample_np)
|
376 |
+
samples = np.stack(samples)
|
377 |
+
assert samples.shape == (50_000, samples.shape[1], samples.shape[2], 3)
|
378 |
+
npz_path = f'{sample_folder}.npz'
|
379 |
+
np.savez(npz_path, arr_0=samples)
|
380 |
+
print(f'Saved .npz file to {npz_path} [shape={samples.shape}].')
|
381 |
+
return npz_path
|
infrance_example.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from models import VQVAE, build_vae_var
|
3 |
+
from dataset.imagenet_dataset import get_train_transforms
|
4 |
+
from PIL import Image
|
5 |
+
from torchvision import transforms
|
6 |
+
|
7 |
+
|
8 |
+
device = 'mps'
|
9 |
+
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
|
10 |
+
|
11 |
+
vae, var = build_vae_var(
|
12 |
+
V=4096, Cvae=32, ch=160, share_quant_resi=4,
|
13 |
+
device=device, patch_nums=patch_nums,
|
14 |
+
num_classes=1000, depth=16, shared_aln=False,
|
15 |
+
)
|
16 |
+
var_ckpt='var_d16.pth'
|
17 |
+
vae_ckpt='vae_ch160v4096z32.pth'
|
18 |
+
var.load_state_dict(torch.load(var_ckpt, map_location=device), strict=True)
|
19 |
+
vae.load_state_dict(torch.load(vae_ckpt, map_location=device), strict=True)
|
model-step-step=32000.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4222b988d8108a1288367e282654e1e7819dd4d1d34b4b1efecec924d94d718d
|
3 |
+
size 4161643659
|
vae_ch160v4096z32.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7c3ec27ae28a3f87055e83211ea8cc8558bd1985d7b51742d074fb4c2fcf186c
|
3 |
+
size 436075834
|