AmitIsraeli commited on
Commit
fc8623e
1 Parent(s): ded2f46

add checkpoint VAR trained on pops

Browse files
VAR/.gitignore ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.swp
2
+ **/__pycache__/**
3
+ **/.ipynb_checkpoints/**
4
+ .DS_Store
5
+ .idea/*
6
+ .vscode/*
7
+ llava/
8
+ _vis_cached/
9
+ _auto_*
10
+ ckpt/
11
+ log/
12
+ tb*/
13
+ img*/
14
+ local_output*
15
+ *.pth
16
+ *.pth.tar
17
+ *.ckpt
18
+ *.log
19
+ *.txt
20
+ *.ipynb
VAR/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 FoundationVision
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
VAR/README.md ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VAR: a new visual generation method elevates GPT-style models beyond diffusion🚀 & Scaling laws observed📈
2
+
3
+ <div align="center">
4
+
5
+ [![demo platform](https://img.shields.io/badge/Play%20with%20VAR%21-VAR%20demo%20platform-lightblue)](https://var.vision/demo)&nbsp;
6
+ [![arXiv](https://img.shields.io/badge/arXiv%20paper-2404.02905-b31b1b.svg)](https://arxiv.org/abs/2404.02905)&nbsp;
7
+ [![huggingface weights](https://img.shields.io/badge/%F0%9F%A4%97%20Weights-FoundationVision/var-yellow)](https://huggingface.co/FoundationVision/var)&nbsp;
8
+ [![SOTA](https://img.shields.io/badge/State%20of%20the%20Art-Image%20Generation%20on%20ImageNet%20%28AR%29-32B1B4?logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iNjA2IiBoZWlnaHQ9IjYwNiIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIiB4bWxuczp4bGluaz0iaHR0cDovL3d3dy53My5vcmcvMTk5OS94bGluayIgb3ZlcmZsb3c9ImhpZGRlbiI%2BPGRlZnM%2BPGNsaXBQYXRoIGlkPSJjbGlwMCI%2BPHJlY3QgeD0iLTEiIHk9Ii0xIiB3aWR0aD0iNjA2IiBoZWlnaHQ9IjYwNiIvPjwvY2xpcFBhdGg%2BPC9kZWZzPjxnIGNsaXAtcGF0aD0idXJsKCNjbGlwMCkiIHRyYW5zZm9ybT0idHJhbnNsYXRlKDEgMSkiPjxyZWN0IHg9IjUyOSIgeT0iNjYiIHdpZHRoPSI1NiIgaGVpZ2h0PSI0NzMiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIxOSIgeT0iNjYiIHdpZHRoPSI1NyIgaGVpZ2h0PSI0NzMiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIyNzQiIHk9IjE1MSIgd2lkdGg9IjU3IiBoZWlnaHQ9IjMwMiIgZmlsbD0iIzQ0RjJGNiIvPjxyZWN0IHg9IjEwNCIgeT0iMTUxIiB3aWR0aD0iNTciIGhlaWdodD0iMzAyIiBmaWxsPSIjNDRGMkY2Ii8%2BPHJlY3QgeD0iNDQ0IiB5PSIxNTEiIHdpZHRoPSI1NyIgaGVpZ2h0PSIzMDIiIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSIzNTkiIHk9IjE3MCIgd2lkdGg9IjU2IiBoZWlnaHQ9IjI2NCIgZmlsbD0iIzQ0RjJGNiIvPjxyZWN0IHg9IjE4OCIgeT0iMTcwIiB3aWR0aD0iNTciIGhlaWdodD0iMjY0IiBmaWxsPSIjNDRGMkY2Ii8%2BPHJlY3QgeD0iNzYiIHk9IjY2IiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI0ODIiIHk9IjY2IiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI3NiIgeT0iNDgyIiB3aWR0aD0iNDciIGhlaWdodD0iNTciIGZpbGw9IiM0NEYyRjYiLz48cmVjdCB4PSI0ODIiIHk9IjQ4MiIgd2lkdGg9IjQ3IiBoZWlnaHQ9IjU3IiBmaWxsPSIjNDRGMkY2Ii8%2BPC9nPjwvc3ZnPg%3D%3D)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?tag_filter=485&p=visual-autoregressive-modeling-scalable-image)
9
+
10
+
11
+ </div>
12
+ <p align="center" style="font-size: larger;">
13
+ <a href="https://arxiv.org/abs/2404.02905">Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction</a>
14
+ </p>
15
+
16
+ <div>
17
+ <p align="center" style="font-size: larger;">
18
+ <strong>NeurIPS 2024 Oral</strong>
19
+ </p>
20
+ </div>
21
+
22
+ <p align="center">
23
+ <img src="https://github.com/FoundationVision/VAR/assets/39692511/9850df90-20b1-4f29-8592-e3526d16d755" width=95%>
24
+ <p>
25
+
26
+ <br>
27
+
28
+ ## News
29
+
30
+ * **2024-09:** VAR is accepted as **NeurIPS 2024 Oral** Presentation.
31
+ * **2024-04:** [Visual AutoRegressive modeling](https://github.com/FoundationVision/VAR) is released.
32
+
33
+ ## 🕹️ Try and Play with VAR!
34
+
35
+ We provide a [demo website](https://var.vision/demo) for you to play with VAR models and generate images interactively. Enjoy the fun of visual autoregressive modeling!
36
+
37
+ We also provide [demo_sample.ipynb](demo_sample.ipynb) for you to see more technical details about VAR.
38
+
39
+ [//]: # (<p align="center">)
40
+ [//]: # (<img src="https://user-images.githubusercontent.com/39692511/226376648-3f28a1a6-275d-4f88-8f3e-cd1219882488.png" width=50%)
41
+ [//]: # (<p>)
42
+
43
+
44
+ ## What's New?
45
+
46
+ ### 🔥 Introducing VAR: a new paradigm in autoregressive visual generation✨:
47
+
48
+ Visual Autoregressive Modeling (VAR) redefines the autoregressive learning on images as coarse-to-fine "next-scale prediction" or "next-resolution prediction", diverging from the standard raster-scan "next-token prediction".
49
+
50
+ <p align="center">
51
+ <img src="https://github.com/FoundationVision/VAR/assets/39692511/3e12655c-37dc-4528-b923-ec6c4cfef178" width=93%>
52
+ <p>
53
+
54
+ ### 🔥 For the first time, GPT-style autoregressive models surpass diffusion models🚀:
55
+ <p align="center">
56
+ <img src="https://github.com/FoundationVision/VAR/assets/39692511/cc30b043-fa4e-4d01-a9b1-e50650d5675d" width=55%>
57
+ <p>
58
+
59
+
60
+ ### 🔥 Discovering power-law Scaling Laws in VAR transformers📈:
61
+
62
+
63
+ <p align="center">
64
+ <img src="https://github.com/FoundationVision/VAR/assets/39692511/c35fb56e-896e-4e4b-9fb9-7a1c38513804" width=85%>
65
+ <p>
66
+ <p align="center">
67
+ <img src="https://github.com/FoundationVision/VAR/assets/39692511/91d7b92c-8fc3-44d9-8fb4-73d6cdb8ec1e" width=85%>
68
+ <p>
69
+
70
+
71
+ ### 🔥 Zero-shot generalizability🛠️:
72
+
73
+ <p align="center">
74
+ <img src="https://github.com/FoundationVision/VAR/assets/39692511/a54a4e52-6793-4130-bae2-9e459a08e96a" width=70%>
75
+ <p>
76
+
77
+ #### For a deep dive into our analyses, discussions, and evaluations, check out our [paper](https://arxiv.org/abs/2404.02905).
78
+
79
+
80
+ ## VAR zoo
81
+ We provide VAR models for you to play with, which are on <a href='https://huggingface.co/FoundationVision/var'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Huggingface-FoundationVision/var-yellow'></a> or can be downloaded from the following links:
82
+
83
+ | model | reso. | FID | rel. cost | #params | HF weights🤗 |
84
+ |:----------:|:-----:|:--------:|:---------:|:-------:|:------------------------------------------------------------------------------------|
85
+ | VAR-d16 | 256 | 3.55 | 0.4 | 310M | [var_d16.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d16.pth) |
86
+ | VAR-d20 | 256 | 2.95 | 0.5 | 600M | [var_d20.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d20.pth) |
87
+ | VAR-d24 | 256 | 2.33 | 0.6 | 1.0B | [var_d24.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d24.pth) |
88
+ | VAR-d30 | 256 | 1.97 | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |
89
+ | VAR-d30-re | 256 | **1.80** | 1 | 2.0B | [var_d30.pth](https://huggingface.co/FoundationVision/var/resolve/main/var_d30.pth) |
90
+
91
+ You can load these models to generate images via the codes in [demo_sample.ipynb](demo_sample.ipynb). Note: you need to download [vae_ch160v4096z32.pth](https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth) first.
92
+
93
+
94
+ ## Installation
95
+
96
+ 1. Install `torch>=2.0.0`.
97
+ 2. Install other pip packages via `pip3 install -r requirements.txt`.
98
+ 3. Prepare the [ImageNet](http://image-net.org/) dataset
99
+ <details>
100
+ <summary> assume the ImageNet is in `/path/to/imagenet`. It should be like this:</summary>
101
+
102
+ ```
103
+ /path/to/imagenet/:
104
+ train/:
105
+ n01440764:
106
+ many_images.JPEG ...
107
+ n01443537:
108
+ many_images.JPEG ...
109
+ val/:
110
+ n01440764:
111
+ ILSVRC2012_val_00000293.JPEG ...
112
+ n01443537:
113
+ ILSVRC2012_val_00000236.JPEG ...
114
+ ```
115
+ **NOTE: The arg `--data_path=/path/to/imagenet` should be passed to the training script.**
116
+ </details>
117
+
118
+ 5. (Optional) install and compile `flash-attn` and `xformers` for faster attention computation. Our code will automatically use them if installed. See [models/basic_var.py#L15-L30](models/basic_var.py#L15-L30).
119
+
120
+
121
+ ## Training Scripts
122
+
123
+ To train VAR-{d16, d20, d24, d30, d36-s} on ImageNet 256x256 or 512x512, you can run the following command:
124
+ ```shell
125
+ # d16, 256x256
126
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
127
+ --depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1
128
+ # d20, 256x256
129
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
130
+ --depth=20 --bs=768 --ep=250 --fp16=1 --alng=1e-3 --wpe=0.1
131
+ # d24, 256x256
132
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
133
+ --depth=24 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-4 --wpe=0.01
134
+ # d30, 256x256
135
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
136
+ --depth=30 --bs=1024 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-5 --wpe=0.01 --twde=0.08
137
+ # d36-s, 512x512 (-s means saln=1, shared AdaLN)
138
+ torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
139
+ --depth=36 --saln=1 --pn=512 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=5e-6 --wpe=0.01 --twde=0.08
140
+ ```
141
+ A folder named `local_output` will be created to save the checkpoints and logs.
142
+ You can monitor the training process by checking the logs in `local_output/log.txt` and `local_output/stdout.txt`, or using `tensorboard --logdir=local_output/`.
143
+
144
+ If your experiment is interrupted, just rerun the command, and the training will **automatically resume** from the last checkpoint in `local_output/ckpt*.pth` (see [utils/misc.py#L344-L357](utils/misc.py#L344-L357)).
145
+
146
+ ## Sampling & Zero-shot Inference
147
+
148
+ For FID evaluation, use `var.autoregressive_infer_cfg(..., cfg=1.5, top_p=0.96, top_k=900, more_smooth=False)` to sample 50,000 images (50 per class) and save them as PNG (not JPEG) files in a folder. Pack them into a `.npz` file via `create_npz_from_sample_folder(sample_folder)` in [utils/misc.py#L344](utils/misc.py#L360).
149
+ Then use the [OpenAI's FID evaluation toolkit](https://github.com/openai/guided-diffusion/tree/main/evaluations) and reference ground truth npz file of [256x256](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) or [512x512](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) to evaluate FID, IS, precision, and recall.
150
+
151
+ Note a relatively small `cfg=1.5` is used for trade-off between image quality and diversity. You can adjust it to `cfg=5.0`, or sample with `autoregressive_infer_cfg(..., more_smooth=True)` for **better visual quality**.
152
+ We'll provide the sampling script later.
153
+
154
+ ## License
155
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
156
+
157
+
158
+ ## Citation
159
+ If our work assists your research, feel free to give us a star ⭐ or cite us using:
160
+ ```
161
+ @Article{VAR,
162
+ title={Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction},
163
+ author={Keyu Tian and Yi Jiang and Zehuan Yuan and Bingyue Peng and Liwei Wang},
164
+ year={2024},
165
+ eprint={2404.02905},
166
+ archivePrefix={arXiv},
167
+ primaryClass={cs.CV}
168
+ }
169
+ ```
VAR/demo_sample.ipynb ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "### 🚀 For an interactive experience, head over to our [demo platform](https://var.vision/demo) and dive right in! 🌟"
7
+ ],
8
+ "metadata": {
9
+ "collapsed": false
10
+ }
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 3,
15
+ "outputs": [
16
+ {
17
+ "name": "stderr",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "/var/folders/xv/4sfwyf3j1q72_7wzmsc7s2t00000gn/T/ipykernel_17646/3181200107.py:32: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
21
+ " vae.load_state_dict(torch.load(vae_ckpt, map_location=device), strict=True)\n",
22
+ "/var/folders/xv/4sfwyf3j1q72_7wzmsc7s2t00000gn/T/ipykernel_17646/3181200107.py:33: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
23
+ " var.load_state_dict(torch.load(var_ckpt, map_location=device), strict=True)\n"
24
+ ]
25
+ },
26
+ {
27
+ "name": "stdout",
28
+ "output_type": "stream",
29
+ "text": [
30
+ "prepare finished.\n"
31
+ ]
32
+ }
33
+ ],
34
+ "source": [
35
+ "################## 1. Download checkpoints and build models\n",
36
+ "import os\n",
37
+ "import os.path as osp\n",
38
+ "import torch, torchvision\n",
39
+ "import random\n",
40
+ "import numpy as np\n",
41
+ "import PIL.Image as PImage, PIL.ImageDraw as PImageDraw\n",
42
+ "setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed\n",
43
+ "from models import VQVAE, build_vae_var\n",
44
+ "\n",
45
+ "MODEL_DEPTH = 16 # TODO: =====> please specify MODEL_DEPTH <=====\n",
46
+ "assert MODEL_DEPTH in {16, 20, 24, 30}\n",
47
+ "\n",
48
+ "\n",
49
+ "# download checkpoint\n",
50
+ "hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'\n",
51
+ "vae_ckpt, var_ckpt = '/Users/mac/Downloads/vae_ch160v4096z32.pth', '/Users/mac/Downloads/var_d16.pth'\n",
52
+ "if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')\n",
53
+ "if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')\n",
54
+ "\n",
55
+ "# build vae, var\n",
56
+ "patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)\n",
57
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
58
+ "if 'vae' not in globals() or 'var' not in globals():\n",
59
+ " vae, var = build_vae_var(\n",
60
+ " V=4096, Cvae=32, ch=160, share_quant_resi=4, # hard-coded VQVAE hyperparameters\n",
61
+ " device=device, patch_nums=patch_nums,\n",
62
+ " num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,\n",
63
+ " )\n",
64
+ "\n",
65
+ "# load checkpoints\n",
66
+ "vae.load_state_dict(torch.load(vae_ckpt, map_location=device), strict=True)\n",
67
+ "var.load_state_dict(torch.load(var_ckpt, map_location=device), strict=True)\n",
68
+ "vae.eval(), var.eval()\n",
69
+ "for p in vae.parameters(): p.requires_grad_(False)\n",
70
+ "for p in var.parameters(): p.requires_grad_(False)\n",
71
+ "print(f'prepare finished.')"
72
+ ],
73
+ "metadata": {
74
+ "collapsed": false,
75
+ "ExecuteTime": {
76
+ "end_time": "2024-11-09T23:44:58.680070Z",
77
+ "start_time": "2024-11-09T23:44:58.094258Z"
78
+ }
79
+ }
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": 4,
84
+ "outputs": [
85
+ {
86
+ "ename": "RuntimeError",
87
+ "evalue": "Placeholder storage has not been allocated on MPS device!",
88
+ "output_type": "error",
89
+ "traceback": [
90
+ "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
91
+ "\u001B[0;31mRuntimeError\u001B[0m Traceback (most recent call last)",
92
+ "Cell \u001B[0;32mIn[4], line 28\u001B[0m\n\u001B[1;32m 26\u001B[0m label_B: torch\u001B[38;5;241m.\u001B[39mLongTensor \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mtensor(class_labels, device\u001B[38;5;241m=\u001B[39mdevice)\n\u001B[1;32m 27\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m torch\u001B[38;5;241m.\u001B[39minference_mode():\n\u001B[0;32m---> 28\u001B[0m recon_B3HW \u001B[38;5;241m=\u001B[39m \u001B[43mvar\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautoregressive_infer_cfg\u001B[49m\u001B[43m(\u001B[49m\u001B[43mB\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mB\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlabel_B\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mlabel_B\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcfg\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mcfg\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtop_k\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m900\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtop_p\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0.95\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mg_seed\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mseed\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmore_smooth\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mmore_smooth\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 30\u001B[0m chw \u001B[38;5;241m=\u001B[39m torchvision\u001B[38;5;241m.\u001B[39mutils\u001B[38;5;241m.\u001B[39mmake_grid(recon_B3HW, nrow\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m8\u001B[39m, padding\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m, pad_value\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1.0\u001B[39m)\n\u001B[1;32m 31\u001B[0m chw \u001B[38;5;241m=\u001B[39m chw\u001B[38;5;241m.\u001B[39mpermute(\u001B[38;5;241m1\u001B[39m, \u001B[38;5;241m2\u001B[39m, \u001B[38;5;241m0\u001B[39m)\u001B[38;5;241m.\u001B[39mmul_(\u001B[38;5;241m255\u001B[39m)\u001B[38;5;241m.\u001B[39mcpu()\u001B[38;5;241m.\u001B[39mnumpy()\n",
93
+ "File \u001B[0;32m~/PycharmProjects/ALL_Shit/venv/lib/python3.9/site-packages/torch/utils/_contextlib.py:116\u001B[0m, in \u001B[0;36mcontext_decorator.<locals>.decorate_context\u001B[0;34m(*args, **kwargs)\u001B[0m\n\u001B[1;32m 113\u001B[0m \u001B[38;5;129m@functools\u001B[39m\u001B[38;5;241m.\u001B[39mwraps(func)\n\u001B[1;32m 114\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mdecorate_context\u001B[39m(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m 115\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m ctx_factory():\n\u001B[0;32m--> 116\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
94
+ "File \u001B[0;32m~/PycharmProjects/VAR_clip/VAR/models/var.py:154\u001B[0m, in \u001B[0;36mVAR.autoregressive_infer_cfg\u001B[0;34m(self, B, label_B, cond_delta, g_seed, cfg, top_k, top_p, beta, more_smooth)\u001B[0m\n\u001B[1;32m 151\u001B[0m \u001B[38;5;28;01melif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(label_B, \u001B[38;5;28mint\u001B[39m):\n\u001B[1;32m 152\u001B[0m label_B \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mfull((B,), fill_value\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mnum_classes \u001B[38;5;28;01mif\u001B[39;00m label_B \u001B[38;5;241m<\u001B[39m \u001B[38;5;241m0\u001B[39m \u001B[38;5;28;01melse\u001B[39;00m label_B, device\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mlvl_1L\u001B[38;5;241m.\u001B[39mdevice)\n\u001B[0;32m--> 154\u001B[0m sos \u001B[38;5;241m=\u001B[39m cond_BD \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mclass_emb\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcat\u001B[49m\u001B[43m(\u001B[49m\u001B[43m(\u001B[49m\u001B[43mlabel_B\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfull_like\u001B[49m\u001B[43m(\u001B[49m\u001B[43mlabel_B\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfill_value\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mnum_classes\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 155\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m cond_delta \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m 156\u001B[0m cond_BD[:B] \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m cond_delta \u001B[38;5;241m*\u001B[39m beta\n",
95
+ "File \u001B[0;32m~/PycharmProjects/ALL_Shit/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1553\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m 1551\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[1;32m 1552\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m-> 1553\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
96
+ "File \u001B[0;32m~/PycharmProjects/ALL_Shit/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1562\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m 1557\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[1;32m 1558\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[1;32m 1559\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[1;32m 1560\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[1;32m 1561\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[0;32m-> 1562\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1564\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m 1565\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n",
97
+ "File \u001B[0;32m~/PycharmProjects/ALL_Shit/venv/lib/python3.9/site-packages/torch/nn/modules/sparse.py:164\u001B[0m, in \u001B[0;36mEmbedding.forward\u001B[0;34m(self, input)\u001B[0m\n\u001B[1;32m 163\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;28minput\u001B[39m: Tensor) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m Tensor:\n\u001B[0;32m--> 164\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mF\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43membedding\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 165\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43minput\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mweight\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mpadding_idx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mmax_norm\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 166\u001B[0m \u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mnorm_type\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mscale_grad_by_freq\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msparse\u001B[49m\u001B[43m)\u001B[49m\n",
98
+ "File \u001B[0;32m~/PycharmProjects/ALL_Shit/venv/lib/python3.9/site-packages/torch/nn/functional.py:2267\u001B[0m, in \u001B[0;36membedding\u001B[0;34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001B[0m\n\u001B[1;32m 2261\u001B[0m \u001B[38;5;66;03m# Note [embedding_renorm set_grad_enabled]\u001B[39;00m\n\u001B[1;32m 2262\u001B[0m \u001B[38;5;66;03m# XXX: equivalent to\u001B[39;00m\n\u001B[1;32m 2263\u001B[0m \u001B[38;5;66;03m# with torch.no_grad():\u001B[39;00m\n\u001B[1;32m 2264\u001B[0m \u001B[38;5;66;03m# torch.embedding_renorm_\u001B[39;00m\n\u001B[1;32m 2265\u001B[0m \u001B[38;5;66;03m# remove once script supports set_grad_enabled\u001B[39;00m\n\u001B[1;32m 2266\u001B[0m _no_grad_embedding_renorm_(weight, \u001B[38;5;28minput\u001B[39m, max_norm, norm_type)\n\u001B[0;32m-> 2267\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43membedding\u001B[49m\u001B[43m(\u001B[49m\u001B[43mweight\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43minput\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mpadding_idx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mscale_grad_by_freq\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msparse\u001B[49m\u001B[43m)\u001B[49m\n",
99
+ "\u001B[0;31mRuntimeError\u001B[0m: Placeholder storage has not been allocated on MPS device!"
100
+ ]
101
+ }
102
+ ],
103
+ "source": [
104
+ "############################# 2. Sample with classifier-free guidance\n",
105
+ "\n",
106
+ "# set args\n",
107
+ "seed = 0 #@param {type:\"number\"}\n",
108
+ "torch.manual_seed(seed)\n",
109
+ "num_sampling_steps = 250 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
110
+ "cfg = 4 #@param {type:\"slider\", min:1, max:10, step:0.1}\n",
111
+ "class_labels = (980, 980, 437, 437, 22, 22, 562, 562) #@param {type:\"raw\"}\n",
112
+ "more_smooth = False # True for more smooth output\n",
113
+ "\n",
114
+ "# seed\n",
115
+ "torch.manual_seed(seed)\n",
116
+ "random.seed(seed)\n",
117
+ "np.random.seed(seed)\n",
118
+ "torch.backends.cudnn.deterministic = True\n",
119
+ "torch.backends.cudnn.benchmark = False\n",
120
+ "\n",
121
+ "# run faster\n",
122
+ "tf32 = True\n",
123
+ "torch.backends.cudnn.allow_tf32 = bool(tf32)\n",
124
+ "torch.backends.cuda.matmul.allow_tf32 = bool(tf32)\n",
125
+ "torch.set_float32_matmul_precision('high' if tf32 else 'highest')\n",
126
+ "\n",
127
+ "# sample\n",
128
+ "B = len(class_labels)\n",
129
+ "label_B: torch.LongTensor = torch.tensor(class_labels, device=device)\n",
130
+ "with torch.inference_mode():\n",
131
+ " recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)\n",
132
+ "\n",
133
+ "chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)\n",
134
+ "chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()\n",
135
+ "chw = PImage.fromarray(chw.astype(np.uint8))\n",
136
+ "chw.show()\n"
137
+ ],
138
+ "metadata": {
139
+ "collapsed": false,
140
+ "ExecuteTime": {
141
+ "end_time": "2024-11-09T23:44:59.930244Z",
142
+ "start_time": "2024-11-09T23:44:59.454228Z"
143
+ }
144
+ }
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "outputs": [],
149
+ "source": [],
150
+ "metadata": {
151
+ "collapsed": false
152
+ }
153
+ }
154
+ ],
155
+ "metadata": {
156
+ "kernelspec": {
157
+ "display_name": "Python 3",
158
+ "language": "python",
159
+ "name": "python3"
160
+ },
161
+ "language_info": {
162
+ "codemirror_mode": {
163
+ "name": "ipython",
164
+ "version": 2
165
+ },
166
+ "file_extension": ".py",
167
+ "mimetype": "text/x-python",
168
+ "name": "python",
169
+ "nbconvert_exporter": "python",
170
+ "pygments_lexer": "ipython2",
171
+ "version": "2.7.6"
172
+ }
173
+ },
174
+ "nbformat": 4,
175
+ "nbformat_minor": 0
176
+ }
VAR/dist.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import functools
3
+ import os
4
+ import sys
5
+ from typing import List
6
+ from typing import Union
7
+
8
+ import torch
9
+ import torch.distributed as tdist
10
+ import torch.multiprocessing as mp
11
+
12
+ __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
13
+ __initialized = False
14
+
15
+
16
+ def initialized():
17
+ return __initialized
18
+
19
+
20
+ def initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout=30):
21
+ global __device
22
+ if not torch.cuda.is_available():
23
+ print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
24
+ return
25
+ elif 'RANK' not in os.environ:
26
+ torch.cuda.set_device(gpu_id_if_not_distibuted)
27
+ __device = torch.empty(1).cuda().device
28
+ print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
29
+ return
30
+ # then 'RANK' must exist
31
+ global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
32
+ local_rank = global_rank % num_gpus
33
+ torch.cuda.set_device(local_rank)
34
+
35
+ # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
36
+ if mp.get_start_method(allow_none=True) is None:
37
+ method = 'fork' if fork else 'spawn'
38
+ print(f'[dist initialize] mp method={method}')
39
+ mp.set_start_method(method)
40
+ tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout*60))
41
+
42
+ global __rank, __local_rank, __world_size, __initialized
43
+ __local_rank = local_rank
44
+ __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
45
+ __device = torch.empty(1).cuda().device
46
+ __initialized = True
47
+
48
+ assert tdist.is_initialized(), 'torch.distributed is not initialized!'
49
+ print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
50
+
51
+
52
+ def get_rank():
53
+ return __rank
54
+
55
+
56
+ def get_local_rank():
57
+ return __local_rank
58
+
59
+
60
+ def get_world_size():
61
+ return __world_size
62
+
63
+
64
+ def get_device():
65
+ return __device
66
+
67
+
68
+ def set_gpu_id(gpu_id: int):
69
+ if gpu_id is None: return
70
+ global __device
71
+ if isinstance(gpu_id, (str, int)):
72
+ torch.cuda.set_device(int(gpu_id))
73
+ __device = torch.empty(1).cuda().device
74
+ else:
75
+ raise NotImplementedError
76
+
77
+
78
+ def is_master():
79
+ return __rank == 0
80
+
81
+
82
+ def is_local_master():
83
+ return __local_rank == 0
84
+
85
+
86
+ def new_group(ranks: List[int]):
87
+ if __initialized:
88
+ return tdist.new_group(ranks=ranks)
89
+ return None
90
+
91
+
92
+ def barrier():
93
+ if __initialized:
94
+ tdist.barrier()
95
+
96
+
97
+ def allreduce(t: torch.Tensor, async_op=False):
98
+ if __initialized:
99
+ if not t.is_cuda:
100
+ cu = t.detach().cuda()
101
+ ret = tdist.all_reduce(cu, async_op=async_op)
102
+ t.copy_(cu.cpu())
103
+ else:
104
+ ret = tdist.all_reduce(t, async_op=async_op)
105
+ return ret
106
+ return None
107
+
108
+
109
+ def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
110
+ if __initialized:
111
+ if not t.is_cuda:
112
+ t = t.cuda()
113
+ ls = [torch.empty_like(t) for _ in range(__world_size)]
114
+ tdist.all_gather(ls, t)
115
+ else:
116
+ ls = [t]
117
+ if cat:
118
+ ls = torch.cat(ls, dim=0)
119
+ return ls
120
+
121
+
122
+ def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
123
+ if __initialized:
124
+ if not t.is_cuda:
125
+ t = t.cuda()
126
+
127
+ t_size = torch.tensor(t.size(), device=t.device)
128
+ ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
129
+ tdist.all_gather(ls_size, t_size)
130
+
131
+ max_B = max(size[0].item() for size in ls_size)
132
+ pad = max_B - t_size[0].item()
133
+ if pad:
134
+ pad_size = (pad, *t.size()[1:])
135
+ t = torch.cat((t, t.new_empty(pad_size)), dim=0)
136
+
137
+ ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
138
+ tdist.all_gather(ls_padded, t)
139
+ ls = []
140
+ for t, size in zip(ls_padded, ls_size):
141
+ ls.append(t[:size[0].item()])
142
+ else:
143
+ ls = [t]
144
+ if cat:
145
+ ls = torch.cat(ls, dim=0)
146
+ return ls
147
+
148
+
149
+ def broadcast(t: torch.Tensor, src_rank) -> None:
150
+ if __initialized:
151
+ if not t.is_cuda:
152
+ cu = t.detach().cuda()
153
+ tdist.broadcast(cu, src=src_rank)
154
+ t.copy_(cu.cpu())
155
+ else:
156
+ tdist.broadcast(t, src=src_rank)
157
+
158
+
159
+ def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
160
+ if not initialized():
161
+ return torch.tensor([val]) if fmt is None else [fmt % val]
162
+
163
+ ts = torch.zeros(__world_size)
164
+ ts[__rank] = val
165
+ allreduce(ts)
166
+ if fmt is None:
167
+ return ts
168
+ return [fmt % v for v in ts.cpu().numpy().tolist()]
169
+
170
+
171
+ def master_only(func):
172
+ @functools.wraps(func)
173
+ def wrapper(*args, **kwargs):
174
+ force = kwargs.pop('force', False)
175
+ if force or is_master():
176
+ ret = func(*args, **kwargs)
177
+ else:
178
+ ret = None
179
+ barrier()
180
+ return ret
181
+ return wrapper
182
+
183
+
184
+ def local_master_only(func):
185
+ @functools.wraps(func)
186
+ def wrapper(*args, **kwargs):
187
+ force = kwargs.pop('force', False)
188
+ if force or is_local_master():
189
+ ret = func(*args, **kwargs)
190
+ else:
191
+ ret = None
192
+ barrier()
193
+ return ret
194
+ return wrapper
195
+
196
+
197
+ def for_visualize(func):
198
+ @functools.wraps(func)
199
+ def wrapper(*args, **kwargs):
200
+ if is_master():
201
+ # with torch.no_grad():
202
+ ret = func(*args, **kwargs)
203
+ else:
204
+ ret = None
205
+ return ret
206
+ return wrapper
207
+
208
+
209
+ def finalize():
210
+ if __initialized:
211
+ tdist.destroy_process_group()
VAR/models/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import torch.nn as nn
3
+
4
+ from .quant import VectorQuantizer2
5
+ from .var import VAR
6
+ from .vqvae import VQVAE
7
+
8
+
9
+ def build_vae_var(
10
+ # Shared args
11
+ device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
12
+ # VQVAE args
13
+ V=4096, Cvae=32, ch=160, share_quant_resi=4,
14
+ # VAR args
15
+ num_classes=1000, depth=16, shared_aln=False, attn_l2_norm=True,
16
+ flash_if_available=True, fused_if_available=True,
17
+ init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1, # init_std < 0: automated
18
+ ) -> Tuple[VQVAE, VAR]:
19
+ heads = depth
20
+ width = depth * 64
21
+ dpr = 0.1 * depth/24
22
+
23
+ # disable built-in initialization for speed
24
+ for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d):
25
+ setattr(clz, 'reset_parameters', lambda self: None)
26
+
27
+ # build models
28
+ vae_local = VQVAE(vocab_size=V, z_channels=Cvae, ch=ch, test_mode=True, share_quant_resi=share_quant_resi, v_patch_nums=patch_nums).to(device)
29
+ var_wo_ddp = VAR(
30
+ vae_local=vae_local,
31
+ num_classes=num_classes, depth=depth, embed_dim=width, num_heads=heads, drop_rate=0., attn_drop_rate=0., drop_path_rate=dpr,
32
+ norm_eps=1e-6, shared_aln=shared_aln, cond_drop_rate=0.1,
33
+ attn_l2_norm=attn_l2_norm,
34
+ patch_nums=patch_nums,
35
+ flash_if_available=flash_if_available, fused_if_available=fused_if_available,
36
+ ).to(device)
37
+ var_wo_ddp.init_weights(init_adaln=init_adaln, init_adaln_gamma=init_adaln_gamma, init_head=init_head, init_std=init_std)
38
+
39
+ return vae_local, var_wo_ddp
VAR/models/basic_vae.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ # this file only provides the 2 modules used in VQVAE
7
+ __all__ = ['Encoder', 'Decoder',]
8
+
9
+
10
+ """
11
+ References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
12
+ """
13
+ # swish
14
+ def nonlinearity(x):
15
+ return x * torch.sigmoid(x)
16
+
17
+
18
+ def Normalize(in_channels, num_groups=32):
19
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
20
+
21
+
22
+ class Upsample2x(nn.Module):
23
+ def __init__(self, in_channels):
24
+ super().__init__()
25
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
26
+
27
+ def forward(self, x):
28
+ return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
29
+
30
+
31
+ class Downsample2x(nn.Module):
32
+ def __init__(self, in_channels):
33
+ super().__init__()
34
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
35
+
36
+ def forward(self, x):
37
+ return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0))
38
+
39
+
40
+ class ResnetBlock(nn.Module):
41
+ def __init__(self, *, in_channels, out_channels=None, dropout): # conv_shortcut=False, # conv_shortcut: always False in VAE
42
+ super().__init__()
43
+ self.in_channels = in_channels
44
+ out_channels = in_channels if out_channels is None else out_channels
45
+ self.out_channels = out_channels
46
+
47
+ self.norm1 = Normalize(in_channels)
48
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
49
+ self.norm2 = Normalize(out_channels)
50
+ self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()
51
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
52
+ if self.in_channels != self.out_channels:
53
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
54
+ else:
55
+ self.nin_shortcut = nn.Identity()
56
+
57
+ def forward(self, x):
58
+ h = self.conv1(F.silu(self.norm1(x), inplace=True))
59
+ h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True)))
60
+ return self.nin_shortcut(x) + h
61
+
62
+
63
+ class AttnBlock(nn.Module):
64
+ def __init__(self, in_channels):
65
+ super().__init__()
66
+ self.C = in_channels
67
+
68
+ self.norm = Normalize(in_channels)
69
+ self.qkv = torch.nn.Conv2d(in_channels, 3*in_channels, kernel_size=1, stride=1, padding=0)
70
+ self.w_ratio = int(in_channels) ** (-0.5)
71
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
72
+
73
+ def forward(self, x):
74
+ qkv = self.qkv(self.norm(x))
75
+ B, _, H, W = qkv.shape # should be B,3C,H,W
76
+ C = self.C
77
+ q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)
78
+
79
+ # compute attention
80
+ q = q.view(B, C, H * W).contiguous()
81
+ q = q.permute(0, 2, 1).contiguous() # B,HW,C
82
+ k = k.view(B, C, H * W).contiguous() # B,C,HW
83
+ w = torch.bmm(q, k).mul_(self.w_ratio) # B,HW,HW w[B,i,j]=sum_c q[B,i,C]k[B,C,j]
84
+ w = F.softmax(w, dim=2)
85
+
86
+ # attend to values
87
+ v = v.view(B, C, H * W).contiguous()
88
+ w = w.permute(0, 2, 1).contiguous() # B,HW,HW (first HW of k, second of q)
89
+ h = torch.bmm(v, w) # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]
90
+ h = h.view(B, C, H, W).contiguous()
91
+
92
+ return x + self.proj_out(h)
93
+
94
+
95
+ def make_attn(in_channels, using_sa=True):
96
+ return AttnBlock(in_channels) if using_sa else nn.Identity()
97
+
98
+
99
+ class Encoder(nn.Module):
100
+ def __init__(
101
+ self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
102
+ dropout=0.0, in_channels=3,
103
+ z_channels, double_z=False, using_sa=True, using_mid_sa=True,
104
+ ):
105
+ super().__init__()
106
+ self.ch = ch
107
+ self.num_resolutions = len(ch_mult)
108
+ self.downsample_ratio = 2 ** (self.num_resolutions - 1)
109
+ self.num_res_blocks = num_res_blocks
110
+ self.in_channels = in_channels
111
+
112
+ # downsampling
113
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
114
+
115
+ in_ch_mult = (1,) + tuple(ch_mult)
116
+ self.down = nn.ModuleList()
117
+ for i_level in range(self.num_resolutions):
118
+ block = nn.ModuleList()
119
+ attn = nn.ModuleList()
120
+ block_in = ch * in_ch_mult[i_level]
121
+ block_out = ch * ch_mult[i_level]
122
+ for i_block in range(self.num_res_blocks):
123
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
124
+ block_in = block_out
125
+ if i_level == self.num_resolutions - 1 and using_sa:
126
+ attn.append(make_attn(block_in, using_sa=True))
127
+ down = nn.Module()
128
+ down.block = block
129
+ down.attn = attn
130
+ if i_level != self.num_resolutions - 1:
131
+ down.downsample = Downsample2x(block_in)
132
+ self.down.append(down)
133
+
134
+ # middle
135
+ self.mid = nn.Module()
136
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
137
+ self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
138
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
139
+
140
+ # end
141
+ self.norm_out = Normalize(block_in)
142
+ self.conv_out = torch.nn.Conv2d(block_in, (2 * z_channels if double_z else z_channels), kernel_size=3, stride=1, padding=1)
143
+
144
+ def forward(self, x):
145
+ # downsampling
146
+ h = self.conv_in(x)
147
+ for i_level in range(self.num_resolutions):
148
+ for i_block in range(self.num_res_blocks):
149
+ h = self.down[i_level].block[i_block](h)
150
+ if len(self.down[i_level].attn) > 0:
151
+ h = self.down[i_level].attn[i_block](h)
152
+ if i_level != self.num_resolutions - 1:
153
+ h = self.down[i_level].downsample(h)
154
+
155
+ # middle
156
+ h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))
157
+
158
+ # end
159
+ h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
160
+ return h
161
+
162
+
163
+ class Decoder(nn.Module):
164
+ def __init__(
165
+ self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
166
+ dropout=0.0, in_channels=3, # in_channels: raw img channels
167
+ z_channels, using_sa=True, using_mid_sa=True,
168
+ ):
169
+ super().__init__()
170
+ self.ch = ch
171
+ self.num_resolutions = len(ch_mult)
172
+ self.num_res_blocks = num_res_blocks
173
+ self.in_channels = in_channels
174
+
175
+ # compute in_ch_mult, block_in and curr_res at lowest res
176
+ in_ch_mult = (1,) + tuple(ch_mult)
177
+ block_in = ch * ch_mult[self.num_resolutions - 1]
178
+
179
+ # z to block_in
180
+ self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
181
+
182
+ # middle
183
+ self.mid = nn.Module()
184
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
185
+ self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
186
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
187
+
188
+ # upsampling
189
+ self.up = nn.ModuleList()
190
+ for i_level in reversed(range(self.num_resolutions)):
191
+ block = nn.ModuleList()
192
+ attn = nn.ModuleList()
193
+ block_out = ch * ch_mult[i_level]
194
+ for i_block in range(self.num_res_blocks + 1):
195
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
196
+ block_in = block_out
197
+ if i_level == self.num_resolutions-1 and using_sa:
198
+ attn.append(make_attn(block_in, using_sa=True))
199
+ up = nn.Module()
200
+ up.block = block
201
+ up.attn = attn
202
+ if i_level != 0:
203
+ up.upsample = Upsample2x(block_in)
204
+ self.up.insert(0, up) # prepend to get consistent order
205
+
206
+ # end
207
+ self.norm_out = Normalize(block_in)
208
+ self.conv_out = torch.nn.Conv2d(block_in, in_channels, kernel_size=3, stride=1, padding=1)
209
+
210
+ def forward(self, z):
211
+ # z to block_in
212
+ # middle
213
+ h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))
214
+
215
+ # upsampling
216
+ for i_level in reversed(range(self.num_resolutions)):
217
+ for i_block in range(self.num_res_blocks + 1):
218
+ h = self.up[i_level].block[i_block](h)
219
+ if len(self.up[i_level].attn) > 0:
220
+ h = self.up[i_level].attn[i_block](h)
221
+ if i_level != 0:
222
+ h = self.up[i_level].upsample(h)
223
+
224
+ # end
225
+ h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
226
+ return h
VAR/models/basic_var.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from models.helpers import DropPath, drop_path
8
+
9
+
10
+ # this file only provides the 3 blocks used in VAR transformer
11
+ __all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead']
12
+
13
+
14
+ # automatically import fused operators
15
+ dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None
16
+ try:
17
+ from flash_attn.ops.layer_norm import dropout_add_layer_norm
18
+ from flash_attn.ops.fused_dense import fused_mlp_func
19
+ except ImportError: pass
20
+ # automatically import faster attention implementations
21
+ try: from xformers.ops import memory_efficient_attention
22
+ except ImportError: pass
23
+ try: from flash_attn import flash_attn_func # qkv: BLHc, ret: BLHcq
24
+ except ImportError: pass
25
+ try: from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc
26
+ except ImportError:
27
+ def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):
28
+ attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL
29
+ if attn_mask is not None: attn.add_(attn_mask)
30
+ return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value
31
+
32
+
33
+ class FFN(nn.Module):
34
+ def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True):
35
+ super().__init__()
36
+ self.fused_mlp_func = fused_mlp_func if fused_if_available else None
37
+ out_features = out_features or in_features
38
+ hidden_features = hidden_features or in_features
39
+ self.fc1 = nn.Linear(in_features, hidden_features)
40
+ self.act = nn.GELU(approximate='tanh')
41
+ self.fc2 = nn.Linear(hidden_features, out_features)
42
+ self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()
43
+
44
+ def forward(self, x):
45
+ if self.fused_mlp_func is not None:
46
+ return self.drop(self.fused_mlp_func(
47
+ x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias,
48
+ activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0,
49
+ heuristic=0, process_group=None,
50
+ ))
51
+ else:
52
+ return self.drop(self.fc2( self.act(self.fc1(x)) ))
53
+
54
+ def extra_repr(self) -> str:
55
+ return f'fused_mlp_func={self.fused_mlp_func is not None}'
56
+
57
+
58
+ class SelfAttention(nn.Module):
59
+ def __init__(
60
+ self, block_idx, embed_dim=768, num_heads=12,
61
+ attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True,
62
+ ):
63
+ super().__init__()
64
+ assert embed_dim % num_heads == 0
65
+ self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads # =64
66
+ self.attn_l2_norm = attn_l2_norm
67
+ if self.attn_l2_norm:
68
+ self.scale = 1
69
+ self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
70
+ self.max_scale_mul = torch.log(torch.tensor(100)).item()
71
+ else:
72
+ self.scale = 0.25 / math.sqrt(self.head_dim)
73
+
74
+ self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
75
+ self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
76
+ self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
77
+
78
+ self.proj = nn.Linear(embed_dim, embed_dim)
79
+ self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
80
+ self.attn_drop: float = attn_drop
81
+ self.using_flash = flash_if_available and flash_attn_func is not None
82
+ self.using_xform = flash_if_available and memory_efficient_attention is not None
83
+
84
+ # only used during inference
85
+ self.caching, self.cached_k, self.cached_v = False, None, None
86
+
87
+ def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None
88
+
89
+ # NOTE: attn_bias is None during inference because kv cache is enabled
90
+ def forward(self, x, attn_bias):
91
+ B, L, C = x.shape
92
+
93
+ qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim)
94
+ main_type = qkv.dtype
95
+ # qkv: BL3Hc
96
+
97
+ using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
98
+ if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1 # q or k or v: BLHc
99
+ else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2 # q or k or v: BHLc
100
+
101
+ if self.attn_l2_norm:
102
+ scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
103
+ if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2) # 1H11 to 11H1
104
+ q = F.normalize(q, dim=-1).mul(scale_mul)
105
+ k = F.normalize(k, dim=-1)
106
+
107
+ if self.caching:
108
+ if self.cached_k is None: self.cached_k = k; self.cached_v = v
109
+ else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)
110
+
111
+ dropout_p = self.attn_drop if self.training else 0.0
112
+ if using_flash:
113
+ oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)
114
+ elif self.using_xform:
115
+ oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C)
116
+ else:
117
+ oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C)
118
+
119
+ return self.proj_drop(self.proj(oup))
120
+ # attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb()) # BHLc @ BHcL => BHLL
121
+ # attn = self.attn_drop(attn.softmax(dim=-1))
122
+ # oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1) # BHLL @ BHLc = BHLc => BLHc => BLC
123
+
124
+ def extra_repr(self) -> str:
125
+ return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}'
126
+
127
+
128
+ class AdaLNSelfAttn(nn.Module):
129
+ def __init__(
130
+ self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer,
131
+ num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False,
132
+ flash_if_available=False, fused_if_available=True,
133
+ ):
134
+ super(AdaLNSelfAttn, self).__init__()
135
+ self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
136
+ self.C, self.D = embed_dim, cond_dim
137
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
138
+ self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available)
139
+ self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available)
140
+
141
+ self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
142
+ self.shared_aln = shared_aln
143
+ if self.shared_aln:
144
+ self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
145
+ else:
146
+ lin = nn.Linear(cond_dim, 6*embed_dim)
147
+ self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)
148
+
149
+ self.fused_add_norm_fn = None
150
+
151
+ # NOTE: attn_bias is None during inference because kv cache is enabled
152
+ def forward(self, x, cond_BD, attn_bias): # C: embed_dim, D: cond_dim
153
+ if self.shared_aln:
154
+ gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
155
+ else:
156
+ gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
157
+ x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1))
158
+ x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be done in-place when FusedMLP is used
159
+ return x
160
+
161
+ def extra_repr(self) -> str:
162
+ return f'shared_aln={self.shared_aln}'
163
+
164
+
165
+ class AdaLNBeforeHead(nn.Module):
166
+ def __init__(self, C, D, norm_layer): # C: embed_dim, D: cond_dim
167
+ super().__init__()
168
+ self.C, self.D = C, D
169
+ self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
170
+ self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C))
171
+
172
+ def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
173
+ scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
174
+ return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
VAR/models/helpers.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # returns sampled indices, shaped (B, l, num_samples); masks logits_BlV in place
7
+ B, l, V = logits_BlV.shape
8
+ if top_k > 0:
9
+ idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
10
+ logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
11
+ if top_p > 0:
12
+ sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
13
+ sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
14
+ sorted_idx_to_remove[..., -1:] = False
15
+ logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf)
16
+ # sample (flatten to 2D first, since torch.multinomial only accepts 2D tensors)
17
+ replacement = num_samples >= 0
18
+ num_samples = abs(num_samples)
19
+ return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
20
+
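A minimal usage sketch of the sampler above (assuming `models.helpers` is on the import path); note the trailing underscore in the name: the logits tensor is masked in place, so clone it first if the original values are still needed.

```python
import torch

from models.helpers import sample_with_top_k_top_p_

logits = torch.randn(2, 3, 8)  # hypothetical (B, l, V) logits over an 8-token vocabulary
idx = sample_with_top_k_top_p_(logits.clone(), top_k=4, top_p=0.9, num_samples=1)
print(idx.shape)  # torch.Size([2, 3, 1])
```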
21
+
22
+ def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, rng: torch.Generator = None) -> torch.Tensor:
23
+ if rng is None:
24
+ return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)
25
+
26
+ gumbels = (-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_(generator=rng).log())
27
+ gumbels = (logits + gumbels) / tau
28
+ y_soft = gumbels.softmax(dim)
29
+
30
+ if hard:
31
+ index = y_soft.max(dim, keepdim=True)[1]
32
+ y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
33
+ ret = y_hard - y_soft.detach() + y_soft # straight-through estimator: hard one-hot forward, soft gradients backward
34
+ else:
35
+ ret = y_soft
36
+ return ret
37
+
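The `y_hard - y_soft.detach() + y_soft` line above is the classic straight-through estimator; a small standalone sketch of the trick in isolation:

```python
import torch

logits = torch.randn(2, 5, requires_grad=True)
y_soft = logits.softmax(dim=-1)
index = y_soft.max(dim=-1, keepdim=True)[1]
y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
out = y_hard - y_soft.detach() + y_soft  # forward: one-hot; backward: gradient of y_soft
out.sum().backward()
print(logits.grad is not None)  # True: gradients reach the logits despite the hard argmax
```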
38
+
39
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): # taken from timm
40
+ if drop_prob == 0. or not training: return x
41
+ keep_prob = 1 - drop_prob
42
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
43
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
44
+ if keep_prob > 0.0 and scale_by_keep:
45
+ random_tensor.div_(keep_prob)
46
+ return x * random_tensor
47
+
48
+
49
+ class DropPath(nn.Module): # taken from timm
50
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
51
+ super(DropPath, self).__init__()
52
+ self.drop_prob = drop_prob
53
+ self.scale_by_keep = scale_by_keep
54
+
55
+ def forward(self, x):
56
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
57
+
58
+ def extra_repr(self):
59
+ return f'drop_prob={self.drop_prob:g}, scale_by_keep={self.scale_by_keep}'
VAR/models/quant.py ADDED
@@ -0,0 +1,281 @@
1
+ from typing import List, Optional, Sequence, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import distributed as tdist, nn as nn
6
+ from torch.nn import functional as F
7
+
8
+ import dist
9
+
10
+ # this file only provides the VectorQuantizer2 used in VQVAE
11
+ __all__ = ['VectorQuantizer2', ]
12
+
13
+
14
+ class VectorQuantizer2(nn.Module):
15
+ # VQGAN originally uses beta=1.0 and never tried 0.25; Stable Diffusion seems to use 0.25
16
+ def __init__(
17
+ self, vocab_size, Cvae, using_znorm, beta: float = 0.25,
18
+ default_qresi_counts=0, v_patch_nums=None, quant_resi=0.5, share_quant_resi=4, # share_quant_resi: args.qsr
19
+ ):
20
+ super().__init__()
21
+ self.vocab_size: int = vocab_size
22
+ self.Cvae: int = Cvae
23
+ self.using_znorm: bool = using_znorm
24
+ self.v_patch_nums: Tuple[int] = v_patch_nums
25
+
26
+ self.quant_resi_ratio = quant_resi
27
+ if share_quant_resi == 0: # non-shared: \phi_{1 to K} for K scales
28
+ self.quant_resi = PhiNonShared(
29
+ [(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in
30
+ range(default_qresi_counts or len(self.v_patch_nums))])
31
+ elif share_quant_resi == 1: # fully shared: only a single \phi for K scales
32
+ self.quant_resi = PhiShared(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
33
+ else: # partially shared: \phi_{1 to share_quant_resi} for K scales
34
+ self.quant_resi = PhiPartiallyShared(nn.ModuleList(
35
+ [(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in
36
+ range(share_quant_resi)]))
37
+
38
+ self.register_buffer('ema_vocab_hit_SV', torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0))
39
+ self.record_hit = 0
40
+
41
+ self.beta: float = beta
42
+ self.embedding = nn.Embedding(self.vocab_size, self.Cvae)
43
+
44
+ # only used for progressive training of VAR (not supported yet, will be tested and supported in the future)
45
+ self.prog_si = -1 # progressive training: not supported yet, prog_si always -1
46
+
47
+ def eini(self, eini):
48
+ if eini > 0:
49
+ nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
50
+ elif eini < 0:
51
+ self.embedding.weight.data.uniform_(-abs(eini) / self.vocab_size, abs(eini) / self.vocab_size)
52
+
53
+ def extra_repr(self) -> str:
54
+ return f'{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}'
55
+
56
+ # ===================== `forward` is only used in VAE training =====================
57
+ def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, List[float], torch.Tensor]:
58
+ dtype = f_BChw.dtype
59
+ if dtype != torch.float32: f_BChw = f_BChw.float()
60
+ B, C, H, W = f_BChw.shape
61
+ f_no_grad = f_BChw.detach()
62
+
63
+ f_rest = f_no_grad.clone()
64
+ f_hat = torch.zeros_like(f_rest)
65
+
66
+ with torch.cuda.amp.autocast(enabled=False):
67
+ mean_vq_loss: torch.Tensor = 0.0
68
+ vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=f_BChw.device)
69
+ SN = len(self.v_patch_nums)
70
+ for si, pn in enumerate(self.v_patch_nums): # from small to large
71
+ # find the nearest embedding
72
+ if self.using_znorm:
73
+ rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='bilinear').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
76
+ rest_NC = F.normalize(rest_NC, dim=-1)
77
+ idx_N = torch.argmax(rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
78
+ else:
79
+ rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='bilinear').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
82
+ d_no_grad = torch.sum(rest_NC.square(), dim=1, keepdim=True) + torch.sum(
83
+ self.embedding.weight.data.square(), dim=1, keepdim=False)
84
+ d_no_grad.addmm_(rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
85
+ idx_N = torch.argmin(d_no_grad, dim=1)
86
+
87
+ hit_V = idx_N.bincount(minlength=self.vocab_size).float()
88
+ if self.training:
89
+ if dist.initialized(): handler = tdist.all_reduce(hit_V, async_op=True)
90
+
91
+ # calc loss
92
+ idx_Bhw = idx_N.view(B, pn, pn)
93
+ h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bilinear').contiguous() if (si != SN - 1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
96
+ h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
97
+ f_hat = f_hat + h_BChw
98
+ f_rest -= h_BChw
99
+
100
+ if self.training and dist.initialized():
101
+ handler.wait()
102
+ if self.record_hit == 0:
103
+ self.ema_vocab_hit_SV[si].copy_(hit_V)
104
+ elif self.record_hit < 100:
105
+ self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
106
+ else:
107
+ self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
108
+ self.record_hit += 1
109
+ vocab_hit_V.add_(hit_V)
110
+ mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)
111
+
112
+ mean_vq_loss *= 1. / SN
113
+ f_hat = (f_hat.data - f_no_grad).add_(f_BChw)
114
+
115
+ margin = tdist.get_world_size() * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08
116
+ # margin = pn*pn / 100
117
+ if ret_usages:
118
+ usages = [(self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 for si, pn in
119
+ enumerate(self.v_patch_nums)]
120
+ else:
121
+ usages = None
122
+ return f_hat, usages, mean_vq_loss
123
+
124
+ # ===================== `forward` is only used in VAE training =====================
125
+
126
+ def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False) -> Union[
127
+ List[torch.Tensor], torch.Tensor]:
128
+ ls_f_hat_BChw = []
129
+ B = ms_h_BChw[0].shape[0]
130
+ H = W = self.v_patch_nums[-1]
131
+ SN = len(self.v_patch_nums)
132
+ if all_to_max_scale:
133
+ f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
134
+ for si, pn in enumerate(self.v_patch_nums): # from small to large
135
+ h_BChw = ms_h_BChw[si]
136
+ if si < len(self.v_patch_nums) - 1:
137
+ h_BChw = F.interpolate(h_BChw, size=(H, W), mode='bilinear')
138
+ h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
139
+ f_hat.add_(h_BChw)
140
+ if last_one:
141
+ ls_f_hat_BChw = f_hat
142
+ else:
143
+ ls_f_hat_BChw.append(f_hat.clone())
144
+ else:
145
+ # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)
146
+ # WARNING: this should only be used for experimental purpose
147
+ f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, self.v_patch_nums[0], self.v_patch_nums[0],
148
+ dtype=torch.float32)
149
+ for si, pn in enumerate(self.v_patch_nums): # from small to large
150
+ f_hat = F.interpolate(f_hat, size=(pn, pn), mode='bilinear')
151
+ h_BChw = self.quant_resi[si / (SN - 1)](ms_h_BChw[si])
152
+ f_hat.add_(h_BChw)
153
+ if last_one:
154
+ ls_f_hat_BChw = f_hat
155
+ else:
156
+ ls_f_hat_BChw.append(f_hat)
157
+
158
+ return ls_f_hat_BChw
159
+
160
+ def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool,
161
+ v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[
162
+ Union[torch.Tensor, torch.LongTensor]]: # z_BChw is the feature from inp_img_no_grad
163
+ B, C, H, W = f_BChw.shape
164
+ f_no_grad = f_BChw.detach()
165
+ f_rest = f_no_grad.clone()
166
+ f_hat = torch.zeros_like(f_rest)
167
+
168
+ f_hat_or_idx_Bl: List[torch.Tensor] = []
169
+
170
+ patch_hws = [(pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) for pn in
171
+ (v_patch_nums or self.v_patch_nums)] # from small to large
172
+ assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})'
173
+
174
+ SN = len(patch_hws)
175
+ for si, (ph, pw) in enumerate(patch_hws): # from small to large
176
+ if 0 <= self.prog_si < si: break # progressive training: not supported yet, prog_si always -1
177
+ # find the nearest embedding
178
+ z_NC = F.interpolate(f_rest, size=(ph, pw), mode='bilinear').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
180
+ if self.using_znorm:
181
+ z_NC = F.normalize(z_NC, dim=-1)
182
+ idx_N = torch.argmax(z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
183
+ else:
184
+ d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(
185
+ self.embedding.weight.data.square(), dim=1, keepdim=False)
186
+ d_no_grad.addmm_(z_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
187
+ idx_N = torch.argmin(d_no_grad, dim=1)
188
+
189
+ idx_Bhw = idx_N.view(B, ph, pw)
190
+ h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bilinear').contiguous() if (si != SN - 1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
193
+ h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
194
+ f_hat.add_(h_BChw)
195
+ f_rest.sub_(h_BChw)
196
+ f_hat_or_idx_Bl.append(f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw))
197
+
198
+ return f_hat_or_idx_Bl
199
+
200
+ # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
201
+ def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
202
+ next_scales = []
203
+ B = gt_ms_idx_Bl[0].shape[0]
204
+ C = self.Cvae
205
+ H = W = self.v_patch_nums[-1]
206
+ SN = len(self.v_patch_nums)
207
+
208
+ f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
209
+ pn_next: int = self.v_patch_nums[0]
210
+ for si in range(SN - 1):
211
+ if self.prog_si == 0 or (0 <= self.prog_si - 1 < si): break # progressive training: not supported yet, prog_si always -1
213
+ h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next),
214
+ size=(H, W), mode='bilinear')
215
+ f_hat.add_(self.quant_resi[si / (SN - 1)](h_BChw))
216
+ pn_next = self.v_patch_nums[si + 1]
217
+ next_scales.append(
218
+ F.interpolate(f_hat, size=(pn_next, pn_next), mode='bilinear').view(B, C, -1).transpose(1, 2))
219
+ return torch.cat(next_scales, dim=1) if len(next_scales) else None # cat BlCs to BLC, this should be float32
220
+
221
+ # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
222
+ def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor) -> Tuple[
223
+ Optional[torch.Tensor], torch.Tensor]: # only used in VAR inference
224
+ HW = self.v_patch_nums[-1]
225
+ if si != SN - 1:
226
+ h = self.quant_resi[si / (SN - 1)](
227
+ F.interpolate(h_BChw, size=(HW, HW), mode='bilinear')) # conv after upsample
228
+ f_hat.add_(h)
229
+ return f_hat, F.interpolate(f_hat, size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]),
230
+ mode='bilinear')
231
+ else:
232
+ h = self.quant_resi[si / (SN - 1)](h_BChw)
233
+ f_hat.add_(h)
234
+ return f_hat, f_hat
235
+
236
+
237
+ class Phi(nn.Conv2d):
238
+ def __init__(self, embed_dim, quant_resi):
239
+ ks = 3
240
+ super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
241
+ self.resi_ratio = abs(quant_resi)
242
+
243
+ def forward(self, h_BChw):
244
+ return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)
245
+
246
+
247
+ class PhiShared(nn.Module):
248
+ def __init__(self, qresi: Phi):
249
+ super().__init__()
250
+ self.qresi: Phi = qresi
251
+
252
+ def __getitem__(self, _) -> Phi:
253
+ return self.qresi
254
+
255
+
256
+ class PhiPartiallyShared(nn.Module):
257
+ def __init__(self, qresi_ls: nn.ModuleList):
258
+ super().__init__()
259
+ self.qresi_ls = qresi_ls
260
+ K = len(qresi_ls)
261
+ self.ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) if K == 4 else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
262
+
263
+ def __getitem__(self, at_from_0_to_1: float) -> Phi:
264
+ return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]
265
+
266
+ def extra_repr(self) -> str:
267
+ return f'ticks={self.ticks}'
268
+
269
+
270
+ class PhiNonShared(nn.ModuleList):
271
+ def __init__(self, qresi: List):
272
+ super().__init__(qresi)
273
+ # self.qresi = qresi
274
+ K = len(qresi)
275
+ self.ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) if K == 4 else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
276
+
277
+ def __getitem__(self, at_from_0_to_1: float) -> Phi:
278
+ return super().__getitem__(np.argmin(np.abs(self.ticks - at_from_0_to_1)).item())
279
+
280
+ def extra_repr(self) -> str:
281
+ return f'ticks={self.ticks}'
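A quick sketch of how `PhiPartiallyShared` maps the continuous scale position `si / (SN - 1)` onto its K shared convolutions, using the defaults above (K=4 shared phi layers, SN=10 scales):

```python
import numpy as np

K, SN = 4, 10
ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K)  # [0.083, 0.361, 0.639, 0.917]
for si in range(SN):
    at = si / (SN - 1)
    print(si, int(np.argmin(np.abs(ticks - at))))  # each scale picks the nearest shared phi
```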
VAR/models/var.py ADDED
@@ -0,0 +1,323 @@
1
+ import math
2
+ from functools import partial
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import tqdm
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+
10
+ import dist
11
+ from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn
12
+ from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
13
+ from models.vqvae import VQVAE, VectorQuantizer2
14
+ import lovely_tensors as lt # optional dependency: prettier tensor printing for debugging (not listed in requirements.txt)
15
+ lt.monkey_patch()
16
+
17
+ class SharedAdaLin(nn.Linear):
18
+ def forward(self, cond_BD):
19
+ C = self.weight.shape[0] // 6
20
+ return super().forward(cond_BD).view(-1, 1, 6, C) # B16C
21
+
22
+
23
+ class VAR(nn.Module):
24
+ def __init__(
25
+ self, vae_local: VQVAE,
26
+ num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
27
+ norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
28
+ attn_l2_norm=False,
29
+ patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
30
+ flash_if_available=True, fused_if_available=True,
31
+ ):
32
+ super().__init__()
33
+ # 0. hyperparameters
34
+ assert embed_dim % num_heads == 0
35
+ self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size
36
+ self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads
37
+
38
+ self.cond_drop_rate = cond_drop_rate
39
+ self.prog_si = -1 # progressive training
40
+
41
+ self.patch_nums: Tuple[int] = patch_nums
42
+ self.L = sum(pn ** 2 for pn in self.patch_nums)
43
+ self.first_l = self.patch_nums[0] ** 2
44
+ self.begin_ends = []
45
+ cur = 0
46
+ for i, pn in enumerate(self.patch_nums):
47
+ self.begin_ends.append((cur, cur+pn ** 2))
48
+ cur += pn ** 2
49
+
50
+ self.num_stages_minus_1 = len(self.patch_nums) - 1
51
+ self.rng = torch.Generator(device="mps") # NOTE: hard-coded for Apple-Silicon MPS; VARTrainer later replaces this with a generator on the training device
52
+
53
+ # 1. input (word) embedding
54
+ quant: VectorQuantizer2 = vae_local.quantize
55
+ self.vae_proxy: Tuple[VQVAE] = (vae_local,)
56
+ self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,)
57
+ self.word_embed = nn.Linear(self.Cvae, self.C)
58
+
59
+ # 2. class embedding
60
+ init_std = math.sqrt(1 / self.C / 3)
61
+ self.num_classes = num_classes
62
+ self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32, device=dist.get_device())
63
+ self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
64
+ nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
65
+ self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
66
+ nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
67
+
68
+ # 3. absolute position embedding
69
+ pos_1LC = []
70
+ for i, pn in enumerate(self.patch_nums):
71
+ pe = torch.empty(1, pn*pn, self.C)
72
+ nn.init.trunc_normal_(pe, mean=0, std=init_std)
73
+ pos_1LC.append(pe)
74
+ pos_1LC = torch.cat(pos_1LC, dim=1) # 1, L, C
75
+ assert tuple(pos_1LC.shape) == (1, self.L, self.C)
76
+ self.pos_1LC = nn.Parameter(pos_1LC)
77
+ # level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid)
78
+ self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
79
+ nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
80
+
81
+ # 4. backbone blocks
82
+ self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()
83
+
84
+ norm_layer = partial(nn.LayerNorm, eps=norm_eps)
85
+ self.drop_path_rate = drop_path_rate
86
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule (linearly increasing)
87
+ self.blocks = nn.ModuleList([
88
+ AdaLNSelfAttn(
89
+ cond_dim=self.D, shared_aln=shared_aln,
90
+ block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio,
91
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx], last_drop_p=0 if block_idx == 0 else dpr[block_idx-1],
92
+ attn_l2_norm=attn_l2_norm,
93
+ flash_if_available=flash_if_available, fused_if_available=fused_if_available,
94
+ )
95
+ for block_idx in range(depth)
96
+ ])
97
+
98
+ fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
99
+ self.using_fused_add_norm_fn = any(fused_add_norm_fns)
100
+ print(
101
+ f'\n[constructor] ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n'
102
+ f' [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n'
103
+ f' [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
104
+ end='\n\n', flush=True
105
+ )
106
+
107
+ # 5. attention mask used in training (for masking out the future)
108
+ # it won't be used in inference, since kv cache is enabled
109
+ d: torch.Tensor = torch.cat([torch.full((pn*pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L, 1)
110
+ dT = d.transpose(1, 2) # dT: 11L
111
+ lvl_1L = dT[:, 0].contiguous()
112
+ self.register_buffer('lvl_1L', lvl_1L)
113
+ attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L)
114
+ self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())
115
+
116
+ # 6. classifier head
117
+ self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
118
+ self.head = nn.Linear(self.C, self.V)
119
+
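The attention bias built in step 5 is block-causal over scales rather than token-causal. A toy sketch with `patch_nums=(1, 2)` makes the structure visible:

```python
import torch

patch_nums = (1, 2)
L = sum(pn * pn for pn in patch_nums)  # 1 + 4 = 5 tokens across 2 scales
d = torch.cat([torch.full((pn * pn,), i) for i, pn in enumerate(patch_nums)]).view(1, L, 1)
attn_bias = torch.where(d >= d.transpose(1, 2), 0., -torch.inf).reshape(1, 1, L, L)
print(attn_bias[0, 0])
# row 0 (the scale-0 token) sees only itself; rows 1-4 (scale 1) see all 5 tokens:
# a token attends within its own scale and to all coarser scales, never to finer ones
```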
120
+ def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], cond_BD: Optional[torch.Tensor]):
121
+ if not isinstance(h_or_h_and_residual, torch.Tensor):
122
+ h, resi = h_or_h_and_residual # fused_add_norm must be used
123
+ h = resi + self.blocks[-1].drop_path(h)
124
+ else: # fused_add_norm is not used
125
+ h = h_or_h_and_residual
126
+ return self.head(self.head_nm(h.float(), cond_BD).float()).float()
127
+
128
+ @torch.no_grad()
129
+ def autoregressive_infer_cfg(
130
+ self, B: int, label_B: Optional[Union[int, torch.LongTensor]], cond_delta: torch.Tensor = None,
131
+ g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0, beta=0, alpha=1,
132
+ more_smooth=False,
133
134
+ ) -> torch.Tensor: # returns reconstructed image (B, 3, H, W) in [0, 1]
135
+ """
136
+ only used for inference, in autoregressive mode
137
+ :param B: batch size
138
+ :param label_B: imagenet label; if None, randomly sampled
139
+ :param cond_delta: optional offset added to the class embedding (weighted by beta)
+ :param g_seed: random seed
140
+ :param cfg: classifier-free guidance ratio
141
+ :param top_k: top-k sampling
142
+ :param top_p: top-p sampling
143
+ :param beta: weight of cond_delta
+ :param alpha: scaling factor applied to the class embedding
+ :param more_smooth: smoothing the pred using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
144
+ :return: reconstructed image (B, 3, H, W) in [0, 1]
145
+ """
146
+ if g_seed is None: rng = None
147
+ else: self.rng.manual_seed(g_seed); rng = self.rng
148
+
149
+ if label_B is None:
150
+ label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B)
151
+ elif isinstance(label_B, int):
152
+ label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=self.lvl_1L.device)
153
+
154
+ if alpha == 1:
155
+ sos = cond_BD = self.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=self.num_classes))))
156
+ else:
157
+ sos = cond_BD = self.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=self.num_classes))))
158
+ sos[0] = cond_BD[0] = cond_BD[0]*alpha # sos and cond_BD alias the same tensor; only the first row is scaled
159
+ if cond_delta is not None and beta != 0:
160
+ cond_BD[0] += cond_delta[0] * beta
161
+ sos[0] += sos[0] * beta
162
+
163
+ lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC
164
+ next_token_map = sos.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1) + lvl_pos[:, :self.first_l]
165
+
166
+ cur_L = 0
167
+ f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])
168
+ for b in self.blocks: b.attn.kv_caching(True)
169
+ for si, pn in enumerate(self.patch_nums): # si: i-th segment
170
+ ratio = si / self.num_stages_minus_1
171
+ # last_L = cur_L
172
+ cur_L += pn*pn
173
+ # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'
174
+ cond_BD_or_gss = self.shared_ada_lin(cond_BD)
175
+ x = next_token_map
176
+ AdaLNSelfAttn.forward # no-op attribute access, kept as an IDE navigation hint
177
+ for b in self.blocks:
178
+ x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)
179
+ logits_BlV = self.get_logits(x, cond_BD)
180
+ t = cfg * ratio
181
+ logits_BlV = (1+t) * logits_BlV[:B] - t * logits_BlV[B:]
182
+
183
+ idx_Bl = sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0]
184
+ if not more_smooth: # this is the default case
185
+ h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl) # B, l, Cvae
186
+ else: # not used when evaluating FID/IS/Precision/Recall
187
+ gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git
188
+ h_BChw = gumbel_softmax_with_rng(logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng) @ self.vae_quant_proxy[0].embedding.weight.unsqueeze(0)
189
+
190
+ h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn)
191
+ f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(si, len(self.patch_nums), f_hat, h_BChw)
192
+ if si != self.num_stages_minus_1: # prepare for next stage
193
+ next_token_map = next_token_map.view(B, self.Cvae, -1).transpose(1, 2)
194
+ next_token_map = self.word_embed(next_token_map) + lvl_pos[:, cur_L:cur_L + self.patch_nums[si+1] ** 2]
195
+ next_token_map = next_token_map.repeat(2, 1, 1) # double the batch sizes due to CFG
196
+
197
+ for b in self.blocks: b.attn.kv_caching(False)
198
+ return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5) # de-normalize, from [-1, 1] to [0, 1]
199
+
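The guidance applied above ramps with scale: at stage `si`, `t = cfg * si / (SN - 1)`, and the doubled batch is recombined as `(1 + t) * cond - t * uncond`, so the coarsest scale is unguided and the finest gets the full `cfg`. A tiny standalone sketch with hypothetical logits:

```python
import torch

B, cfg, SN = 2, 1.5, 10
logits_cond, logits_uncond = torch.randn(B, 4, 16), torch.randn(B, 4, 16)
for si in (0, 4, 9):
    t = cfg * si / (SN - 1)  # 0.0 at the coarsest scale, cfg at the finest
    guided = (1 + t) * logits_cond - t * logits_uncond
    print(si, round(t, 3), guided.shape)
```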
200
+ def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor, cond_delta: torch.Tensor = None, beta=0.05, alpha=1) -> torch.Tensor: # returns logits_BLV
201
+ """
202
+ :param label_B: class labels, shape (B,)
203
+ :param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
+ :param cond_delta: optional offset mixed into the class embedding
+ :param beta: weight of cond_delta; alpha: weight of the class embedding (both only used when cond_delta is given)
204
+ :return: logits BLV, V is vocab_size
205
+ """
206
+ bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)
207
+ B = x_BLCv_wo_first_l.shape[0]
208
+ with torch.cuda.amp.autocast(enabled=False):
209
+ label_B = torch.where(torch.rand(B, device=label_B.device) < self.cond_drop_rate, self.num_classes, label_B)
210
+ if cond_delta is not None:
211
+ sos = cond_BD = alpha * self.class_emb(label_B) + beta * cond_delta
212
+ else:
213
+ sos = cond_BD = self.class_emb(label_B)
214
+ sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)
215
+
216
+ if self.prog_si == 0: x_BLC = sos
217
+ else: x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1)
218
+ x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed] # lvl: BLC; pos: 1LC
219
+
220
+ attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
221
+ cond_BD_or_gss = self.shared_ada_lin(cond_BD)
222
+
223
+ # hack: get the dtype if mixed precision is used
224
+ temp = x_BLC.new_ones(8, 8)
225
+ main_type = torch.matmul(temp, temp).dtype
226
+
227
+ x_BLC = x_BLC.to(dtype=main_type)
228
+ cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
229
+ attn_bias = attn_bias.to(dtype=main_type)
230
+
231
+ AdaLNSelfAttn.forward
232
+ for i, b in enumerate(self.blocks):
233
+ x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)
234
+ x_BLC = self.get_logits(x_BLC.float(), cond_BD)
235
+
236
+ if self.prog_si == 0: # zero-valued touches keep word_embed parameters in the autograd graph when only the first scale is trained
237
+ if isinstance(self.word_embed, nn.Linear):
238
+ x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0
239
+ else:
240
+ s = 0
241
+ for p in self.word_embed.parameters():
242
+ if p.requires_grad:
243
+ s += p.view(-1)[0] * 0
244
+ x_BLC[0, 0, 0] += s
245
+ return x_BLC # logits BLV, V is vocab_size
246
+
247
+ def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02):
248
+ if init_std < 0: init_std = (1 / self.C / 3) ** 0.5 # init_std < 0: automated
249
+
250
+ print(f'[init_weights] {type(self).__name__} with {init_std=:g}')
251
+ for m in self.modules():
252
+ with_weight = hasattr(m, 'weight') and m.weight is not None
253
+ with_bias = hasattr(m, 'bias') and m.bias is not None
254
+ if isinstance(m, nn.Linear):
255
+ nn.init.trunc_normal_(m.weight.data, std=init_std)
256
+ if with_bias: m.bias.data.zero_()
257
+ elif isinstance(m, nn.Embedding):
258
+ nn.init.trunc_normal_(m.weight.data, std=init_std)
259
+ if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
260
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
261
+ if with_weight: m.weight.data.fill_(1.)
262
+ if with_bias: m.bias.data.zero_()
263
+ # conv: VAR has no conv, only VQVAE has conv
264
+ elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
265
+ if conv_std_or_gain > 0: nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
266
+ else: nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
267
+ if with_bias: m.bias.data.zero_()
268
+
269
+ if init_head >= 0:
270
+ if isinstance(self.head, nn.Linear):
271
+ self.head.weight.data.mul_(init_head)
272
+ self.head.bias.data.zero_()
273
+ elif isinstance(self.head, nn.Sequential):
274
+ self.head[-1].weight.data.mul_(init_head)
275
+ self.head[-1].bias.data.zero_()
276
+
277
+ if isinstance(self.head_nm, AdaLNBeforeHead):
278
+ self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
279
+ if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
280
+ self.head_nm.ada_lin[-1].bias.data.zero_()
281
+
282
+ depth = len(self.blocks)
283
+ for block_idx, sab in enumerate(self.blocks):
284
+ sab: AdaLNSelfAttn
285
+ sab.attn.proj.weight.data.div_(math.sqrt(2 * depth))
286
+ sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
287
+ if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None:
288
+ nn.init.ones_(sab.ffn.fcg.bias)
289
+ nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
290
+ if hasattr(sab, 'ada_lin'):
291
+ sab.ada_lin[-1].weight.data[2*self.C:].mul_(init_adaln)
292
+ sab.ada_lin[-1].weight.data[:2*self.C].mul_(init_adaln_gamma)
293
+ if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None:
294
+ sab.ada_lin[-1].bias.data.zero_()
295
+ elif hasattr(sab, 'ada_gss'):
296
+ sab.ada_gss.data[:, :, 2:].mul_(init_adaln)
297
+ sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)
298
+
299
+ def extra_repr(self):
300
+ return f'drop_path_rate={self.drop_path_rate:g}'
301
+
302
+
303
+ class VARHF(VAR, PyTorchModelHubMixin):
304
+ # repo_url="https://github.com/FoundationVision/VAR",
305
+ # tags=["image-generation"]):
306
+ def __init__(
307
+ self,
308
+ vae_kwargs,
309
+ num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
310
+ norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
311
+ attn_l2_norm=False,
312
+ patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
313
+ flash_if_available=True, fused_if_available=True,
314
+ ):
315
+ vae_local = VQVAE(**vae_kwargs)
316
+ super().__init__(
317
+ vae_local=vae_local,
318
+ num_classes=num_classes, depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
319
+ norm_eps=norm_eps, shared_aln=shared_aln, cond_drop_rate=cond_drop_rate,
320
+ attn_l2_norm=attn_l2_norm,
321
+ patch_nums=patch_nums,
322
+ flash_if_available=flash_if_available, fused_if_available=fused_if_available,
323
+ )
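Since `VARHF` mixes in `PyTorchModelHubMixin`, it can in principle be loaded with `from_pretrained`, provided the repo stores the constructor kwargs (including `vae_kwargs`). A hedged usage sketch; the repo id is a placeholder and the sampling values are common VAR demo settings, not prescribed by this code:

```python
from models.var import VARHF

var = VARHF.from_pretrained("your-username/var-pops")  # hypothetical repo id
var.eval()
imgs = var.autoregressive_infer_cfg(B=2, label_B=0, cfg=1.5, top_k=900, top_p=0.96)
print(imgs.shape)  # (2, 3, 256, 256) with the default 16x16 final scale and f16 VQVAE
```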
VAR/models/vqvae.py ADDED
@@ -0,0 +1,95 @@
1
+ """
2
+ References:
3
+ - VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110
4
+ - GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213
5
+ - VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14
6
+ """
7
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from .basic_vae import Decoder, Encoder
13
+ from .quant import VectorQuantizer2
14
+
15
+
16
+ class VQVAE(nn.Module):
17
+ def __init__(
18
+ self, vocab_size=4096, z_channels=32, ch=128, dropout=0.0,
19
+ beta=0.25, # commitment loss weight
20
+ using_znorm=False, # whether to normalize when computing the nearest neighbors
21
+ quant_conv_ks=3, # quant conv kernel size
22
+ quant_resi=0.5, # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x
23
+ share_quant_resi=4, # use 4 \phi layers for K scales: partially-shared \phi
24
+ default_qresi_counts=0, # if is 0: automatically set to len(v_patch_nums)
25
+ v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k]
26
+ test_mode=True,
27
+ ):
28
+ super().__init__()
29
+ self.test_mode = test_mode
30
+ self.V, self.Cvae = vocab_size, z_channels
31
+ # ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml
32
+ ddconfig = dict(
33
+ dropout=dropout, ch=ch, z_channels=z_channels,
34
+ in_channels=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2, # from vq-f16/config.yaml above
35
+ using_sa=True, using_mid_sa=True, # from vq-f16/config.yaml above
36
+ # resamp_with_conv=True, # always True, removed.
37
+ )
38
+ ddconfig.pop('double_z', None) # only KL-VAE should use double_z=True
39
+ self.encoder = Encoder(double_z=False, **ddconfig)
40
+ self.decoder = Decoder(**ddconfig)
41
+
42
+ self.vocab_size = vocab_size
43
+ self.downsample = 2 ** (len(ddconfig['ch_mult'])-1)
44
+ self.quantize: VectorQuantizer2 = VectorQuantizer2(
45
+ vocab_size=vocab_size, Cvae=self.Cvae, using_znorm=using_znorm, beta=beta,
46
+ default_qresi_counts=default_qresi_counts, v_patch_nums=v_patch_nums, quant_resi=quant_resi, share_quant_resi=share_quant_resi,
47
+ )
48
+ self.quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)
49
+ self.post_quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)
50
+
51
+ if self.test_mode:
52
+ self.eval()
53
+ [p.requires_grad_(False) for p in self.parameters()]
54
+
55
+ # ===================== `forward` is only used in VAE training =====================
56
+ def forward(self, inp, ret_usages=False): # -> rec_B3HW, idx_N, loss
57
+ VectorQuantizer2.forward # no-op attribute access, kept as an IDE navigation hint
58
+ f_hat, usages, vq_loss = self.quantize(self.quant_conv(self.encoder(inp)), ret_usages=ret_usages)
59
+ return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss
60
+ # ===================== `forward` is only used in VAE training =====================
61
+
62
+ def fhat_to_img(self, f_hat: torch.Tensor):
63
+ return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
64
+
65
+ def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[torch.LongTensor]: # return List[Bl]
66
+ f = self.quant_conv(self.encoder(inp_img_no_grad))
67
+ return self.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=v_patch_nums)
68
+
69
+ def idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
70
+ B = ms_idx_Bl[0].shape[0]
71
+ ms_h_BChw = []
72
+ for idx_Bl in ms_idx_Bl:
73
+ l = idx_Bl.shape[1]
74
+ pn = round(l ** 0.5)
75
+ ms_h_BChw.append(self.quantize.embedding(idx_Bl).transpose(1, 2).view(B, self.Cvae, pn, pn))
76
+ return self.embed_to_img(ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one)
77
+
78
+ def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
79
+ if last_one:
80
+ return self.decoder(self.post_quant_conv(self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True))).clamp_(-1, 1)
81
+ else:
82
+ return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False)]
83
+
84
+ def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, last_one=False) -> List[torch.Tensor]:
85
+ f = self.quant_conv(self.encoder(x))
86
+ ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(f, to_fhat=True, v_patch_nums=v_patch_nums)
87
+ if last_one:
88
+ return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)
89
+ else:
90
+ return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in ls_f_hat_BChw]
91
+
92
+ def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):
93
+ if 'quantize.ema_vocab_hit_SV' in state_dict and state_dict['quantize.ema_vocab_hit_SV'].shape[0] != self.quantize.ema_vocab_hit_SV.shape[0]:
94
+ state_dict['quantize.ema_vocab_hit_SV'] = self.quantize.ema_vocab_hit_SV
95
+ return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
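A round-trip sketch of the multi-scale tokenizer (untrained weights here, so the reconstruction is meaningless; `ch=160` matches the `vae_ch160v4096z32` checkpoint that `train.py` below downloads):

```python
import torch

from models.vqvae import VQVAE

vae = VQVAE(vocab_size=4096, z_channels=32, ch=160, test_mode=True)
img = torch.rand(1, 3, 256, 256) * 2 - 1  # hypothetical input in [-1, 1]
idx_Bl = vae.img_to_idxBl(img)            # one (B, pn*pn) index tensor per scale
print([t.shape[1] for t in idx_Bl])       # [1, 4, 9, 16, 25, 36, 64, 100, 169, 256]
rec = vae.idxBl_to_img(idx_Bl, same_shape=True, last_one=True)
print(rec.shape)                          # torch.Size([1, 3, 256, 256])
```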
VAR/requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ torch~=2.1.0
2
+
3
+ Pillow
4
+ huggingface_hub
5
+ numpy
6
+ pytz
7
+ transformers
8
+ typed-argument-parser
VAR/train.py ADDED
@@ -0,0 +1,335 @@
1
+ import gc
2
+ import os
3
+ import shutil
4
+ import sys
5
+ import time
6
+ import warnings
7
+ from functools import partial
8
+
9
+ import torch
10
+ from torch.utils.data import DataLoader
11
+
12
+ import dist
13
+ from utils import arg_util, misc
14
+ from utils.data import build_dataset
15
+ from utils.data_sampler import DistInfiniteBatchSampler, EvalDistributedSampler
16
+ from utils.misc import auto_resume
17
+
18
+
19
+ def build_everything(args: arg_util.Args):
20
+ # resume
21
+ auto_resume_info, start_ep, start_it, trainer_state, args_state = auto_resume(args, 'ar-ckpt*.pth')
22
+ # create tensorboard logger
23
+ tb_lg: misc.TensorboardLogger
24
+ with_tb_lg = dist.is_master()
25
+ if with_tb_lg:
26
+ os.makedirs(args.tb_log_dir_path, exist_ok=True)
27
+ # noinspection PyTypeChecker
28
+ tb_lg = misc.DistLogger(misc.TensorboardLogger(log_dir=args.tb_log_dir_path, filename_suffix=f'__{misc.time_str("%m%d_%H%M")}'), verbose=True)
29
+ tb_lg.flush()
30
+ else:
31
+ # noinspection PyTypeChecker
32
+ tb_lg = misc.DistLogger(None, verbose=False)
33
+ dist.barrier()
34
+
35
+ # log args
36
+ print(f'global bs={args.glb_batch_size}, local bs={args.batch_size}')
37
+ print(f'initial args:\n{str(args)}')
38
+
39
+ # build data
40
+ if not args.local_debug:
41
+ print(f'[build PT data] ...\n')
42
+ num_classes, dataset_train, dataset_val = build_dataset(
43
+ args.data_path, final_reso=args.data_load_reso, hflip=args.hflip, mid_reso=args.mid_reso,
44
+ )
45
+ types = str((type(dataset_train).__name__, type(dataset_val).__name__))
46
+
47
+ ld_val = DataLoader(
48
+ dataset_val, num_workers=0, pin_memory=True,
49
+ batch_size=round(args.batch_size*1.5), sampler=EvalDistributedSampler(dataset_val, num_replicas=dist.get_world_size(), rank=dist.get_rank()),
50
+ shuffle=False, drop_last=False,
51
+ )
52
+ del dataset_val
53
+
54
+ ld_train = DataLoader(
55
+ dataset=dataset_train, num_workers=args.workers, pin_memory=True,
56
+ generator=args.get_different_generator_for_each_rank(), # worker_init_fn=worker_init_fn,
57
+ batch_sampler=DistInfiniteBatchSampler(
58
+ dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size, same_seed_for_all_ranks=args.same_seed_for_all_ranks,
59
+ shuffle=True, fill_last=True, rank=dist.get_rank(), world_size=dist.get_world_size(), start_ep=start_ep, start_it=start_it,
60
+ ),
61
+ )
62
+ del dataset_train
63
+
64
+ [print(line) for line in auto_resume_info]
65
+ print(f'[dataloader multi processing] ...', end='', flush=True)
66
+ stt = time.time()
67
+ iters_train = len(ld_train)
68
+ ld_train = iter(ld_train)
69
+ # noinspection PyArgumentList
70
+ print(f' [dataloader multi processing](*) finished! ({time.time()-stt:.2f}s)', flush=True, clean=True)
71
+ print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size}, iters_train={iters_train}, types(tr, va)={types}')
72
+
73
+ else:
74
+ num_classes = 1000
75
+ ld_val = ld_train = None
76
+ iters_train = 10
77
+
78
+ # build models
79
+ from torch.nn.parallel import DistributedDataParallel as DDP
80
+ from models import VAR, VQVAE, build_vae_var
81
+ from trainer import VARTrainer
82
+ from utils.amp_sc import AmpOptimizer
83
+ from utils.lr_control import filter_params
84
+
85
+ vae_local, var_wo_ddp = build_vae_var(
86
+ V=4096, Cvae=32, ch=160, share_quant_resi=4, # hard-coded VQVAE hyperparameters
87
+ device=dist.get_device(), patch_nums=args.patch_nums,
88
+ num_classes=num_classes, depth=args.depth, shared_aln=args.saln, attn_l2_norm=args.anorm,
89
+ flash_if_available=args.fuse, fused_if_available=args.fuse,
90
+ init_adaln=args.aln, init_adaln_gamma=args.alng, init_head=args.hd, init_std=args.ini,
91
+ )
92
+
93
+ vae_ckpt = 'vae_ch160v4096z32.pth'
94
+ if dist.is_local_master():
95
+ if not os.path.exists(vae_ckpt):
96
+ os.system(f'wget https://huggingface.co/FoundationVision/var/resolve/main/{vae_ckpt}')
97
+ dist.barrier()
98
+ vae_local.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
99
+
100
+ vae_local: VQVAE = args.compile_model(vae_local, args.vfast)
101
+ var_wo_ddp: VAR = args.compile_model(var_wo_ddp, args.tfast)
102
+ var: DDP = (DDP if dist.initialized() else NullDDP)(var_wo_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
103
+
104
+ print(f'[INIT] VAR model = {var_wo_ddp}\n\n')
105
+ count_p = lambda m: f'{sum(p.numel() for p in m.parameters())/1e6:.2f}'
106
+ print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAE', vae_local), ('VAE.enc', vae_local.encoder), ('VAE.dec', vae_local.decoder), ('VAE.quant', vae_local.quantize))]))
107
+ print(f'[INIT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in (('VAR', var_wo_ddp),)]) + '\n\n')
108
+
109
+ # build optimizer
110
+ names, paras, para_groups = filter_params(var_wo_ddp, nowd_keys={
111
+ 'cls_token', 'start_token', 'task_token', 'cfg_uncond',
112
+ 'pos_embed', 'pos_1LC', 'pos_start', 'start_pos', 'lvl_embed',
113
+ 'gamma', 'beta',
114
+ 'ada_gss', 'moe_bias',
115
+ 'scale_mul',
116
+ })
117
+ opt_clz = {
118
+ 'adam': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),
119
+ 'adamw': partial(torch.optim.AdamW, betas=(0.9, 0.95), fused=args.afuse),
120
+ }[args.opt.lower().strip()]
121
+ opt_kw = dict(lr=args.tlr, weight_decay=0)
122
+ print(f'[INIT] optim={opt_clz}, opt_kw={opt_kw}\n')
123
+
124
+ var_optim = AmpOptimizer(
125
+ mixed_precision=args.fp16, optimizer=opt_clz(params=para_groups, **opt_kw), names=names, paras=paras,
126
+ grad_clip=args.tclip, n_gradient_accumulation=args.ac
127
+ )
128
+ del names, paras, para_groups
129
+
130
+ # build trainer
131
+ trainer = VARTrainer(
132
+ device=args.device, patch_nums=args.patch_nums, resos=args.resos,
133
+ vae_local=vae_local, var_wo_ddp=var_wo_ddp, var=var,
134
+ var_opt=var_optim, label_smooth=args.ls,
135
+ )
136
+ if trainer_state is not None and len(trainer_state):
137
+ trainer.load_state_dict(trainer_state, strict=False, skip_vae=True) # don't load vae again
138
+ del vae_local, var_wo_ddp, var, var_optim
139
+
140
+ if args.local_debug:
141
+ rng = torch.Generator('cpu')
142
+ rng.manual_seed(0)
143
+ B = 4
144
+ inp = torch.rand(B, 3, args.data_load_reso, args.data_load_reso)
145
+ label = torch.ones(B, dtype=torch.long)
146
+
147
+ me = misc.MetricLogger(delimiter=' ')
148
+ trainer.train_step(
149
+ it=0, g_it=0, stepping=True, metric_lg=me, tb_lg=tb_lg,
150
+ inp_B3HW=inp, label_B=label, prog_si=args.pg0, prog_wp_it=20,
151
+ )
152
+ trainer.load_state_dict(trainer.state_dict())
153
+ trainer.train_step(
154
+ it=99, g_it=599, stepping=True, metric_lg=me, tb_lg=tb_lg,
155
+ inp_B3HW=inp, label_B=label, prog_si=-1, prog_wp_it=20,
156
+ )
157
+ print({k: meter.global_avg for k, meter in me.meters.items()})
158
+
159
+ args.dump_log(); tb_lg.flush(); tb_lg.close()
160
+ if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
161
+ sys.stdout.close(), sys.stderr.close()
162
+ exit(0)
163
+
164
+ dist.barrier()
165
+ return (
166
+ tb_lg, trainer, start_ep, start_it,
167
+ iters_train, ld_train, ld_val
168
+ )
169
+
170
+
171
+ def main_training():
172
+ args: arg_util.Args = arg_util.init_dist_and_get_args()
173
+ if args.local_debug:
174
+ torch.autograd.set_detect_anomaly(True)
175
+
176
+ (
177
+ tb_lg, trainer,
178
+ start_ep, start_it,
179
+ iters_train, ld_train, ld_val
180
+ ) = build_everything(args)
181
+
182
+ # train
183
+ start_time = time.time()
184
+ best_L_mean, best_L_tail, best_acc_mean, best_acc_tail = 999., 999., -1., -1.
185
+ best_val_loss_mean, best_val_loss_tail, best_val_acc_mean, best_val_acc_tail = 999, 999, -1, -1
186
+
187
+ L_mean, L_tail = -1, -1
188
+ for ep in range(start_ep, args.ep):
189
+ if hasattr(ld_train, 'sampler') and hasattr(ld_train.sampler, 'set_epoch'):
190
+ ld_train.sampler.set_epoch(ep)
191
+ if ep < 3:
192
+ # noinspection PyArgumentList
193
+ print(f'[{type(ld_train).__name__}] [ld_train.sampler.set_epoch({ep})]', flush=True, force=True)
194
+ tb_lg.set_step(ep * iters_train)
195
+
196
+ stats, (sec, remain_time, finish_time) = train_one_ep(
197
+ ep, ep == start_ep, start_it if ep == start_ep else 0, args, tb_lg, ld_train, iters_train, trainer
198
+ )
199
+
200
+ L_mean, L_tail, acc_mean, acc_tail, grad_norm = stats['Lm'], stats['Lt'], stats['Accm'], stats['Acct'], stats['tnm']
201
+ best_L_mean, best_acc_mean = min(best_L_mean, L_mean), max(best_acc_mean, acc_mean)
202
+ if L_tail != -1: best_L_tail, best_acc_tail = min(best_L_tail, L_tail), max(best_acc_tail, acc_tail)
203
+ args.L_mean, args.L_tail, args.acc_mean, args.acc_tail, args.grad_norm = L_mean, L_tail, acc_mean, acc_tail, grad_norm
204
+ args.cur_ep = f'{ep+1}/{args.ep}'
205
+ args.remain_time, args.finish_time = remain_time, finish_time
206
+
207
+ AR_ep_loss = dict(L_mean=L_mean, L_tail=L_tail, acc_mean=acc_mean, acc_tail=acc_tail)
208
+ is_val_and_also_saving = (ep + 1) % 10 == 0 or (ep + 1) == args.ep
209
+ if is_val_and_also_saving:
210
+ val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail, tot, cost = trainer.eval_ep(ld_val)
211
+ best_updated = best_val_loss_tail > val_loss_tail
212
+ best_val_loss_mean, best_val_loss_tail = min(best_val_loss_mean, val_loss_mean), min(best_val_loss_tail, val_loss_tail)
213
+ best_val_acc_mean, best_val_acc_tail = max(best_val_acc_mean, val_acc_mean), max(best_val_acc_tail, val_acc_tail)
214
+ AR_ep_loss.update(vL_mean=val_loss_mean, vL_tail=val_loss_tail, vacc_mean=val_acc_mean, vacc_tail=val_acc_tail)
215
+ args.vL_mean, args.vL_tail, args.vacc_mean, args.vacc_tail = val_loss_mean, val_loss_tail, val_acc_mean, val_acc_tail
216
+ print(f' [*] [ep{ep}] (val {tot}) Lm: {L_mean:.4f}, Lt: {L_tail:.4f}, Acc m&t: {acc_mean:.2f} {acc_tail:.2f}, Val cost: {cost:.2f}s')
217
+
218
+ if dist.is_local_master():
219
+ local_out_ckpt = os.path.join(args.local_out_dir_path, 'ar-ckpt-last.pth')
220
+ local_out_ckpt_best = os.path.join(args.local_out_dir_path, 'ar-ckpt-best.pth')
221
+ print(f'[saving ckpt] ...', end='', flush=True)
222
+ torch.save({
223
+ 'epoch': ep+1,
224
+ 'iter': 0,
225
+ 'trainer': trainer.state_dict(),
226
+ 'args': args.state_dict(),
227
+ }, local_out_ckpt)
228
+ if best_updated:
229
+ shutil.copy(local_out_ckpt, local_out_ckpt_best)
230
+ print(f' [saving ckpt](*) finished! @ {local_out_ckpt}', flush=True, clean=True)
231
+ dist.barrier()
232
+
233
+ print( f' [ep{ep}] (training ) Lm: {best_L_mean:.3f} ({L_mean:.3f}), Lt: {best_L_tail:.3f} ({L_tail:.3f}), Acc m&t: {best_acc_mean:.2f} {best_acc_tail:.2f}, Remain: {remain_time}, Finish: {finish_time}', flush=True)
234
+ tb_lg.update(head='AR_ep_loss', step=ep+1, **AR_ep_loss)
235
+ tb_lg.update(head='AR_z_burnout', step=ep+1, rest_hours=round(sec / 60 / 60, 2))
236
+ args.dump_log(); tb_lg.flush()
237
+
238
+ total_time = f'{(time.time() - start_time) / 60 / 60:.1f}h'
239
+ print('\n\n')
240
+ print(f' [*] [PT finished] Total cost: {total_time}, Lm: {best_L_mean:.3f} ({L_mean}), Lt: {best_L_tail:.3f} ({L_tail})')
241
+ print('\n\n')
242
+
243
+ del stats
244
+ del iters_train, ld_train
245
+ time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3)
246
+
247
+ args.remain_time, args.finish_time = '-', time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() - 60))
248
+ print(f'final args:\n\n{str(args)}')
249
+ args.dump_log(); tb_lg.flush(); tb_lg.close()
250
+ dist.barrier()
251
+
252
+
253
+ def train_one_ep(ep: int, is_first_ep: bool, start_it: int, args: arg_util.Args, tb_lg: misc.TensorboardLogger, ld_or_itrt, iters_train: int, trainer):
254
+ # import heavy packages after Dataloader object creation
255
+ from trainer import VARTrainer
256
+ from utils.lr_control import lr_wd_annealing
257
+ trainer: VARTrainer
258
+
259
+ step_cnt = 0
260
+ me = misc.MetricLogger(delimiter=' ')
261
+ me.add_meter('tlr', misc.SmoothedValue(window_size=1, fmt='{value:.2g}'))
262
+ me.add_meter('tnm', misc.SmoothedValue(window_size=1, fmt='{value:.2f}'))
263
+ [me.add_meter(x, misc.SmoothedValue(fmt='{median:.3f} ({global_avg:.3f})')) for x in ['Lm', 'Lt']]
264
+ [me.add_meter(x, misc.SmoothedValue(fmt='{median:.2f} ({global_avg:.2f})')) for x in ['Accm', 'Acct']]
265
+ header = f'[Ep]: [{ep:4d}/{args.ep}]'
266
+
267
+ if is_first_ep:
268
+ warnings.filterwarnings('ignore', category=DeprecationWarning)
269
+ warnings.filterwarnings('ignore', category=UserWarning)
270
+ g_it, max_it = ep * iters_train, args.ep * iters_train
271
+
272
+ for it, (inp, label) in me.log_every(start_it, iters_train, ld_or_itrt, 30 if iters_train > 8000 else 5, header):
273
+ g_it = ep * iters_train + it
274
+ if it < start_it: continue
275
+ if is_first_ep and it == start_it: warnings.resetwarnings()
276
+
277
+ inp = inp.to(args.device, non_blocking=True)
278
+ label = label.to(args.device, non_blocking=True)
279
+
280
+ args.cur_it = f'{it+1}/{iters_train}'
281
+
282
+ wp_it = args.wp * iters_train
283
+ min_tlr, max_tlr, min_twd, max_twd = lr_wd_annealing(args.sche, trainer.var_opt.optimizer, args.tlr, args.twd, args.twde, g_it, wp_it, max_it, wp0=args.wp0, wpe=args.wpe)
284
+ args.cur_lr, args.cur_wd = max_tlr, max_twd
285
+
286
+ if args.pg: # default: args.pg == 0.0, means no progressive training, won't get into this
287
+ if g_it <= wp_it: prog_si = args.pg0
288
+ elif g_it >= max_it*args.pg: prog_si = len(args.patch_nums) - 1
289
+ else:
290
+ delta = len(args.patch_nums) - 1 - args.pg0
291
+ progress = min(max((g_it - wp_it) / (max_it*args.pg - wp_it), 0), 1) # from 0 to 1
292
+ prog_si = args.pg0 + round(progress * delta) # from args.pg0 to len(args.patch_nums)-1
293
+ else:
294
+ prog_si = -1
295
+
296
+ stepping = (g_it + 1) % args.ac == 0
297
+ step_cnt += int(stepping)
298
+
299
+ grad_norm, scale_log2 = trainer.train_step(
300
+ it=it, g_it=g_it, stepping=stepping, metric_lg=me, tb_lg=tb_lg,
301
+ inp_B3HW=inp, label_B=label, prog_si=prog_si, prog_wp_it=args.pgwp * iters_train,
302
+ )
303
+
304
+ me.update(tlr=max_tlr)
305
+ tb_lg.set_step(step=g_it)
306
+ tb_lg.update(head='AR_opt_lr/lr_min', sche_tlr=min_tlr)
307
+ tb_lg.update(head='AR_opt_lr/lr_max', sche_tlr=max_tlr)
308
+ tb_lg.update(head='AR_opt_wd/wd_max', sche_twd=max_twd)
309
+ tb_lg.update(head='AR_opt_wd/wd_min', sche_twd=min_twd)
310
+ tb_lg.update(head='AR_opt_grad/fp16', scale_log2=scale_log2)
311
+
312
+ if args.tclip > 0:
313
+ tb_lg.update(head='AR_opt_grad/grad', grad_norm=grad_norm)
314
+ tb_lg.update(head='AR_opt_grad/grad', grad_clip=args.tclip)
315
+
316
+ me.synchronize_between_processes()
317
+ return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds(max_it - (g_it + 1) + (args.ep - ep) * 15) # +15: other cost
318
+
319
+
320
+ class NullDDP(torch.nn.Module):
321
+ def __init__(self, module, *args, **kwargs):
322
+ super(NullDDP, self).__init__()
323
+ self.module = module
324
+ self.require_backward_grad_sync = False
325
+
326
+ def forward(self, *args, **kwargs):
327
+ return self.module(*args, **kwargs)
328
+
329
+
330
+ if __name__ == '__main__':
331
+ try: main_training()
332
+ finally:
333
+ dist.finalize()
334
+ if isinstance(sys.stdout, misc.SyncPrint) and isinstance(sys.stderr, misc.SyncPrint):
335
+ sys.stdout.close(), sys.stderr.close()
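The progressive-scale schedule inside `train_one_ep` can be read in isolation. A standalone sketch with hypothetical hyperparameters (`wp_it`, `max_it`, `pg`, `pg0` stand in for the corresponding `args` fields):

```python
def prog_schedule(g_it, wp_it=100, max_it=1000, pg=0.5, pg0=4, num_scales=10):
    if g_it <= wp_it:
        return pg0               # during warm-up, train only up to scale pg0
    if g_it >= max_it * pg:
        return num_scales - 1    # past the pg fraction of training, use all scales
    delta = num_scales - 1 - pg0
    progress = min(max((g_it - wp_it) / (max_it * pg - wp_it), 0), 1)
    return pg0 + round(progress * delta)

print([prog_schedule(g) for g in (0, 100, 250, 400, 500, 900)])  # [4, 4, 6, 8, 9, 9]
```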
VAR/trainer.py ADDED
@@ -0,0 +1,201 @@
1
+ import time
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn.parallel import DistributedDataParallel as DDP
7
+ from torch.utils.data import DataLoader
8
+
9
+ import dist
10
+ from models import VAR, VQVAE, VectorQuantizer2
11
+ from utils.amp_sc import AmpOptimizer
12
+ from utils.misc import MetricLogger, TensorboardLogger
13
+
14
+ Ten = torch.Tensor
15
+ FTen = torch.Tensor
16
+ ITen = torch.LongTensor
17
+ BTen = torch.BoolTensor
18
+
19
+
20
+ class VARTrainer(object):
21
+ def __init__(
22
+ self, device, patch_nums: Tuple[int, ...], resos: Tuple[int, ...],
23
+ vae_local: VQVAE, var_wo_ddp: VAR, var: DDP,
24
+ var_opt: AmpOptimizer, label_smooth: float,
25
+ ):
26
+ super(VARTrainer, self).__init__()
27
+
28
+ self.var, self.vae_local, self.quantize_local = var, vae_local, vae_local.quantize
29
+ self.quantize_local: VectorQuantizer2
30
+ self.var_wo_ddp: VAR = var_wo_ddp # after torch.compile
31
+ self.var_opt = var_opt
32
+
33
+ del self.var_wo_ddp.rng
34
+ self.var_wo_ddp.rng = torch.Generator(device=device)
35
+
36
+ self.label_smooth = label_smooth
37
+ self.train_loss = nn.CrossEntropyLoss(label_smoothing=label_smooth, reduction='none')
38
+ self.val_loss = nn.CrossEntropyLoss(label_smoothing=0.0, reduction='mean')
39
+ self.L = sum(pn * pn for pn in patch_nums)
40
+ self.last_l = patch_nums[-1] * patch_nums[-1]
41
+ self.loss_weight = torch.ones(1, self.L, device=device) / self.L
42
+
43
+ self.patch_nums, self.resos = patch_nums, resos
44
+ self.begin_ends = []
45
+ cur = 0
46
+ for i, pn in enumerate(patch_nums):
47
+ self.begin_ends.append((cur, cur + pn * pn))
48
+ cur += pn*pn
49
+
50
+ self.prog_it = 0
51
+ self.last_prog_si = -1
52
+ self.first_prog = True
53
+
54
+ @torch.no_grad()
55
+ def eval_ep(self, ld_val: DataLoader):
56
+ tot = 0
57
+ L_mean, L_tail, acc_mean, acc_tail = 0, 0, 0, 0
58
+ stt = time.time()
59
+ training = self.var_wo_ddp.training
60
+ self.var_wo_ddp.eval()
61
+ for inp_B3HW, label_B in ld_val:
62
+ B, V = label_B.shape[0], self.vae_local.vocab_size
63
+ inp_B3HW = inp_B3HW.to(dist.get_device(), non_blocking=True)
64
+ label_B = label_B.to(dist.get_device(), non_blocking=True)
65
+
66
+ gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)
67
+ gt_BL = torch.cat(gt_idx_Bl, dim=1)
68
+ x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)
69
+
70
+ self.var_wo_ddp.forward
71
+ logits_BLV = self.var_wo_ddp(label_B, x_BLCv_wo_first_l)
72
+ L_mean += self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)) * B
73
+ L_tail += self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)) * B
74
+ acc_mean += (logits_BLV.data.argmax(dim=-1) == gt_BL).sum() * (100/gt_BL.shape[1])
75
+ acc_tail += (logits_BLV.data[:, -self.last_l:].argmax(dim=-1) == gt_BL[:, -self.last_l:]).sum() * (100 / self.last_l)
76
+ tot += B
77
+ self.var_wo_ddp.train(training)
78
+
79
+ stats = L_mean.new_tensor([L_mean.item(), L_tail.item(), acc_mean.item(), acc_tail.item(), tot])
80
+ dist.allreduce(stats)
81
+ tot = round(stats[-1].item())
82
+ stats /= tot
83
+ L_mean, L_tail, acc_mean, acc_tail, _ = stats.tolist()
84
+ return L_mean, L_tail, acc_mean, acc_tail, tot, time.time()-stt
85
+
86
+ def train_step(
87
+ self, it: int, g_it: int, stepping: bool, metric_lg: MetricLogger, tb_lg: TensorboardLogger,
88
+ inp_B3HW: FTen, label_B: Union[ITen, FTen], prog_si: int, prog_wp_it: float,
89
+ ) -> Tuple[Optional[Union[Ten, float]], Optional[float]]:
90
+ # if progressive training
91
+ self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = prog_si
92
+ if self.last_prog_si != prog_si:
93
+ if self.last_prog_si != -1: self.first_prog = False
94
+ self.last_prog_si = prog_si
95
+ self.prog_it = 0
96
+ self.prog_it += 1
97
+ prog_wp = max(min(self.prog_it / prog_wp_it, 1), 0.01)
98
+ if self.first_prog: prog_wp = 1 # no prog warmup at first prog stage, as it's already solved in wp
99
+ if prog_si == len(self.patch_nums) - 1: prog_si = -1 # max prog, as if no prog
100
+
101
+ # forward
102
+ B, V = label_B.shape[0], self.vae_local.vocab_size
103
+ self.var.require_backward_grad_sync = stepping
104
+
105
+ gt_idx_Bl: List[ITen] = self.vae_local.img_to_idxBl(inp_B3HW)
106
+ gt_BL = torch.cat(gt_idx_Bl, dim=1)
107
+ x_BLCv_wo_first_l: Ten = self.quantize_local.idxBl_to_var_input(gt_idx_Bl)
108
+
109
+ with self.var_opt.amp_ctx:
110
+ self.var_wo_ddp.forward
111
+ logits_BLV = self.var(label_B, x_BLCv_wo_first_l)
112
+ loss = self.train_loss(logits_BLV.view(-1, V), gt_BL.view(-1)).view(B, -1)
113
+ if prog_si >= 0: # in progressive training
114
+ bg, ed = self.begin_ends[prog_si]
115
+ assert logits_BLV.shape[1] == gt_BL.shape[1] == ed
116
+ lw = self.loss_weight[:, :ed].clone()
117
+ lw[:, bg:ed] *= min(max(prog_wp, 0), 1)
118
+ else: # not in progressive training
119
+ lw = self.loss_weight
120
+ loss = loss.mul(lw).sum(dim=-1).mean()
121
+
122
+ # backward
123
+ grad_norm, scale_log2 = self.var_opt.backward_clip_step(loss=loss, stepping=stepping)
124
+
125
+ # log
126
+ pred_BL = logits_BLV.data.argmax(dim=-1)
127
+ if it == 0 or it in metric_lg.log_iters:
128
+ Lmean = self.val_loss(logits_BLV.data.view(-1, V), gt_BL.view(-1)).item()
129
+ acc_mean = (pred_BL == gt_BL).float().mean().item() * 100
130
+ if prog_si >= 0: # in progressive training
131
+ Ltail = acc_tail = -1
132
+ else: # not in progressive training
133
+ Ltail = self.val_loss(logits_BLV.data[:, -self.last_l:].reshape(-1, V), gt_BL[:, -self.last_l:].reshape(-1)).item()
134
+ acc_tail = (pred_BL[:, -self.last_l:] == gt_BL[:, -self.last_l:]).float().mean().item() * 100
135
+ grad_norm = grad_norm.item()
136
+ metric_lg.update(Lm=Lmean, Lt=Ltail, Accm=acc_mean, Acct=acc_tail, tnm=grad_norm)
137
+
138
+ # log to tensorboard
139
+ if g_it == 0 or (g_it + 1) % 500 == 0:
140
+ prob_per_class_is_chosen = pred_BL.view(-1).bincount(minlength=V).float()
141
+ dist.allreduce(prob_per_class_is_chosen)
142
+ prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
143
+ cluster_usage = (prob_per_class_is_chosen > 0.001 / V).float().mean().item() * 100
144
+ if dist.is_master():
145
+ if g_it == 0:
146
+ tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-10000)
147
+ tb_lg.update(head='AR_iter_loss', z_voc_usage=cluster_usage, step=-1000)
148
+ kw = dict(z_voc_usage=cluster_usage)
149
+ for si, (bg, ed) in enumerate(self.begin_ends):
150
+ if 0 <= prog_si < si: break
151
+ pred, tar = logits_BLV.data[:, bg:ed].reshape(-1, V), gt_BL[:, bg:ed].reshape(-1)
152
+ acc = (pred.argmax(dim=-1) == tar).float().mean().item() * 100
153
+ ce = self.val_loss(pred, tar).item()
154
+ kw[f'acc_{self.resos[si]}'] = acc
155
+ kw[f'L_{self.resos[si]}'] = ce
156
+ tb_lg.update(head='AR_iter_loss', **kw, step=g_it)
157
+ tb_lg.update(head='AR_iter_schedule', prog_a_reso=self.resos[prog_si], prog_si=prog_si, prog_wp=prog_wp, step=g_it)
158
+
159
+ self.var_wo_ddp.prog_si = self.vae_local.quantize.prog_si = -1
160
+ return grad_norm, scale_log2
161
+
162
+ def get_config(self):
163
+ return {
164
+ 'patch_nums': self.patch_nums, 'resos': self.resos,
165
+ 'label_smooth': self.label_smooth,
166
+ 'prog_it': self.prog_it, 'last_prog_si': self.last_prog_si, 'first_prog': self.first_prog,
167
+ }
168
+
169
+ def state_dict(self):
170
+ state = {'config': self.get_config()}
171
+ for k in ('var_wo_ddp', 'vae_local', 'var_opt'):
172
+ m = getattr(self, k)
173
+ if m is not None:
174
+ if hasattr(m, '_orig_mod'):
175
+ m = m._orig_mod
176
+ state[k] = m.state_dict()
177
+ return state
178
+
179
+ def load_state_dict(self, state, strict=True, skip_vae=False):
180
+ for k in ('var_wo_ddp', 'vae_local', 'var_opt'):
181
+ if skip_vae and 'vae' in k: continue
182
+ m = getattr(self, k)
183
+ if m is not None:
184
+ if hasattr(m, '_orig_mod'):
185
+ m = m._orig_mod
186
+ ret = m.load_state_dict(state[k], strict=strict)
187
+ if ret is not None:
188
+ missing, unexpected = ret
189
+ print(f'[VARTrainer.load_state_dict] {k} missing: {missing}')
190
+ print(f'[VARTrainer.load_state_dict] {k} unexpected: {unexpected}')
191
+
192
+ config: dict = state.pop('config', None)
193
+ self.prog_it = config.get('prog_it', 0)
194
+ self.last_prog_si = config.get('last_prog_si', -1)
195
+ self.first_prog = config.get('first_prog', True)
196
+ if config is not None:
197
+ for k, v in self.get_config().items():
198
+ if config.get(k, None) != v:
199
+ err = f'[VAR.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={config.get(k, None)})'
200
+ if strict: raise AttributeError(err)
201
+ else: print(err)
VAR/utils/amp_sc.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+
6
+
7
+ class NullCtx:
8
+ def __enter__(self):
9
+ pass
10
+
11
+ def __exit__(self, exc_type, exc_val, exc_tb):
12
+ pass
13
+
14
+
15
+ class AmpOptimizer:
16
+ def __init__(
17
+ self,
18
+ mixed_precision: int,
19
+ optimizer: torch.optim.Optimizer, names: List[str], paras: List[torch.nn.Parameter],
20
+ grad_clip: float, n_gradient_accumulation: int = 1,
21
+ ):
22
+ self.enable_amp = mixed_precision > 0
23
+ self.using_fp16_rather_bf16 = mixed_precision == 1
24
+
25
+ if self.enable_amp:
26
+ self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=True)
27
+ self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000) if self.using_fp16_rather_bf16 else None # only fp16 needs a scaler
28
+ else:
29
+ self.amp_ctx = NullCtx()
30
+ self.scaler = None
31
+
32
+ self.optimizer, self.names, self.paras = optimizer, names, paras # paras have been filtered so everyone requires grad
33
+ self.grad_clip = grad_clip
34
+ self.early_clipping = self.grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
35
+ self.late_clipping = self.grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')
36
+
37
+ self.r_accu = 1 / n_gradient_accumulation # r_accu == 1.0 / n_gradient_accumulation
38
+
39
+ def backward_clip_step(
40
+ self, stepping: bool, loss: torch.Tensor,
41
+ ) -> Tuple[Optional[Union[torch.Tensor, float]], Optional[float]]:
42
+ # backward
43
+ loss = loss.mul(self.r_accu) # r_accu == 1.0 / n_gradient_accumulation
44
+ orig_norm = scaler_sc = None
45
+ if self.scaler is not None:
46
+ self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
47
+ else:
48
+ loss.backward(retain_graph=False, create_graph=False)
49
+
50
+ if stepping:
51
+ if self.scaler is not None: self.scaler.unscale_(self.optimizer)
52
+ if self.early_clipping:
53
+ orig_norm = torch.nn.utils.clip_grad_norm_(self.paras, self.grad_clip)
54
+
55
+ if self.scaler is not None:
56
+ self.scaler.step(self.optimizer)
57
+ scaler_sc: float = self.scaler.get_scale()
58
+ if scaler_sc > 32768.: # fp16 will overflow when >65536, so multiply 32768 could be dangerous
59
+ self.scaler.update(new_scale=32768.)
60
+ else:
61
+ self.scaler.update()
62
+ try:
63
+ scaler_sc = float(math.log2(scaler_sc))
64
+ except Exception as e:
65
+ print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
66
+ raise e
67
+ else:
68
+ self.optimizer.step()
69
+
70
+ if self.late_clipping:
71
+ orig_norm = self.optimizer.global_grad_norm
72
+
73
+ self.optimizer.zero_grad(set_to_none=True)
74
+
75
+ return orig_norm, scaler_sc
76
+
77
+ def state_dict(self):
78
+ return {
79
+ 'optimizer': self.optimizer.state_dict()
80
+ } if self.scaler is None else {
81
+ 'scaler': self.scaler.state_dict(),
82
+ 'optimizer': self.optimizer.state_dict()
83
+ }
84
+
85
+ def load_state_dict(self, state, strict=True):
86
+ if self.scaler is not None:
87
+ try: self.scaler.load_state_dict(state['scaler'])
88
+ except Exception as e: print(f'[fp16 load_state_dict err] {e}')
89
+ self.optimizer.load_state_dict(state['optimizer'])
VAR/utils/arg_util.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import re
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from collections import OrderedDict
9
+ from typing import Optional, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ try:
15
+ from tap import Tap
16
+ except ImportError as e:
17
+ print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
18
+ print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
19
+ time.sleep(5)
20
+ raise e
21
+
22
+ import dist
23
+
24
+
25
+ class Args(Tap):
26
+ data_path: str = '/path/to/imagenet'
27
+ exp_name: str = 'text'
28
+
29
+ # VAE
30
+ vfast: int = 0 # torch.compile VAE; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
31
+ # VAR
32
+ tfast: int = 0 # torch.compile VAR; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
33
+ depth: int = 16 # VAR depth
34
+ # VAR initialization
35
+ ini: float = -1 # -1: automated model parameter initialization
36
+ hd: float = 0.02 # head.w *= hd
37
+ aln: float = 0.5 # the multiplier of ada_lin.w's initialization
38
+ alng: float = 1e-5 # the multiplier of ada_lin.w[gamma channels]'s initialization
39
+ # VAR optimization
40
+ fp16: int = 0 # 1: using fp16, 2: bf16
41
+ tblr: float = 1e-4 # base lr
42
+ tlr: float = None # lr = base lr * (bs / 256)
43
+ twd: float = 0.05 # initial wd
44
+ twde: float = 0 # final wd, =twde or twd
45
+ tclip: float = 2. # <=0 for not using grad clip
46
+ ls: float = 0.0 # label smooth
47
+
48
+ bs: int = 768 # global batch size
49
+ batch_size: int = 0 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size() / 8) * 8
50
+ glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
51
+ ac: int = 1 # gradient accumulation
52
+
53
+ ep: int = 250
54
+ wp: float = 0
55
+ wp0: float = 0.005 # initial lr ratio at the begging of lr warm up
56
+ wpe: float = 0.01 # final lr ratio at the end of training
57
+ sche: str = 'lin0' # lr schedule
58
+
59
+ opt: str = 'adamw' # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work
60
+ afuse: bool = True # fused adamw
61
+
62
+ # other hps
63
+ saln: bool = False # whether to use shared adaln
64
+ anorm: bool = True # whether to use L2 normalized attention
65
+ fuse: bool = True # whether to use fused op like flash attn, xformers, fused MLP, fused LayerNorm, etc.
66
+
67
+ # data
68
+ pn: str = '1_2_3_4_5_6_8_10_13_16'
69
+ patch_size: int = 16
70
+ patch_nums: tuple = None # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
71
+ resos: tuple = None # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)
72
+
73
+ data_load_reso: int = None # [automatically set; don't specify this] would be max(patch_nums) * patch_size
74
+ mid_reso: float = 1.125 # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso
75
+ hflip: bool = False # augmentation: horizontal flip
76
+ workers: int = 0 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
77
+
78
+ # progressive training
79
+ pg: float = 0.0 # >0 for use progressive training during [0%, this] of training
80
+ pg0: int = 4 # progressive initial stage, 0: from the 1st token map, 1: from the 2nd token map, etc
81
+ pgwp: float = 0 # num of warmup epochs at each progressive stage
82
+
83
+ # would be automatically set in runtime
84
+ cmd: str = ' '.join(sys.argv[1:]) # [automatically set; don't specify this]
85
+ branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
86
+ commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
87
+ commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
88
+ acc_mean: float = None # [automatically set; don't specify this]
89
+ acc_tail: float = None # [automatically set; don't specify this]
90
+ L_mean: float = None # [automatically set; don't specify this]
91
+ L_tail: float = None # [automatically set; don't specify this]
92
+ vacc_mean: float = None # [automatically set; don't specify this]
93
+ vacc_tail: float = None # [automatically set; don't specify this]
94
+ vL_mean: float = None # [automatically set; don't specify this]
95
+ vL_tail: float = None # [automatically set; don't specify this]
96
+ grad_norm: float = None # [automatically set; don't specify this]
97
+ cur_lr: float = None # [automatically set; don't specify this]
98
+ cur_wd: float = None # [automatically set; don't specify this]
99
+ cur_it: str = '' # [automatically set; don't specify this]
100
+ cur_ep: str = '' # [automatically set; don't specify this]
101
+ remain_time: str = '' # [automatically set; don't specify this]
102
+ finish_time: str = '' # [automatically set; don't specify this]
103
+
104
+ # environment
105
+ local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # [automatically set; don't specify this]
106
+ tb_log_dir_path: str = '...tb-...' # [automatically set; don't specify this]
107
+ log_txt_path: str = '...' # [automatically set; don't specify this]
108
+ last_ckpt_path: str = '...' # [automatically set; don't specify this]
109
+
110
+ tf32: bool = True # whether to use TensorFloat32
111
+ device: str = 'cpu' # [automatically set; don't specify this]
112
+ seed: int = None # seed
113
+ def seed_everything(self, benchmark: bool):
114
+ torch.backends.cudnn.enabled = True
115
+ torch.backends.cudnn.benchmark = benchmark
116
+ if self.seed is None:
117
+ torch.backends.cudnn.deterministic = False
118
+ else:
119
+ torch.backends.cudnn.deterministic = True
120
+ seed = self.seed * dist.get_world_size() + dist.get_rank()
121
+ os.environ['PYTHONHASHSEED'] = str(seed)
122
+ random.seed(seed)
123
+ np.random.seed(seed)
124
+ torch.manual_seed(seed)
125
+ if torch.cuda.is_available():
126
+ torch.cuda.manual_seed(seed)
127
+ torch.cuda.manual_seed_all(seed)
128
+ same_seed_for_all_ranks: int = 0 # this is only for distributed sampler
129
+ def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation
130
+ if self.seed is None:
131
+ return None
132
+ g = torch.Generator()
133
+ g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
134
+ return g
135
+
136
+ local_debug: bool = 'KEVIN_LOCAL' in os.environ
137
+ dbg_nan: bool = False # 'KEVIN_LOCAL' in os.environ
138
+
139
+ def compile_model(self, m, fast):
140
+ if fast == 0 or self.local_debug:
141
+ return m
142
+ return torch.compile(m, mode={
143
+ 1: 'reduce-overhead',
144
+ 2: 'max-autotune',
145
+ 3: 'default',
146
+ }[fast]) if hasattr(torch, 'compile') else m
147
+
148
+ def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
149
+ d = (OrderedDict if key_ordered else dict)()
150
+ # self.as_dict() would contain methods, but we only need variables
151
+ for k in self.class_variables.keys():
152
+ if k not in {'device'}: # these are not serializable
153
+ d[k] = getattr(self, k)
154
+ return d
155
+
156
+ def load_state_dict(self, d: Union[OrderedDict, dict, str]):
157
+ if isinstance(d, str): # for compatibility with old version
158
+ d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
159
+ for k in d.keys():
160
+ try:
161
+ setattr(self, k, d[k])
162
+ except Exception as e:
163
+ print(f'k={k}, v={d[k]}')
164
+ raise e
165
+
166
+ @staticmethod
167
+ def set_tf32(tf32: bool):
168
+ if torch.cuda.is_available():
169
+ torch.backends.cudnn.allow_tf32 = bool(tf32)
170
+ torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
171
+ if hasattr(torch, 'set_float32_matmul_precision'):
172
+ torch.set_float32_matmul_precision('high' if tf32 else 'highest')
173
+ print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
174
+ print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
175
+ print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
176
+
177
+ def dump_log(self):
178
+ if not dist.is_local_master():
179
+ return
180
+ if '1/' in self.cur_ep: # first time to dump log
181
+ with open(self.log_txt_path, 'w') as fp:
182
+ json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)
183
+ fp.write('\n')
184
+
185
+ log_dict = {}
186
+ for k, v in {
187
+ 'it': self.cur_it, 'ep': self.cur_ep,
188
+ 'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,
189
+ 'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,
190
+ 'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,
191
+ 'remain_time': self.remain_time, 'finish_time': self.finish_time,
192
+ }.items():
193
+ if hasattr(v, 'item'): v = v.item()
194
+ log_dict[k] = v
195
+ with open(self.log_txt_path, 'a') as fp:
196
+ fp.write(f'{log_dict}\n')
197
+
198
+ def __str__(self):
199
+ s = []
200
+ for k in self.class_variables.keys():
201
+ if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
202
+ s.append(f' {k:20s}: {getattr(self, k)}')
203
+ s = '\n'.join(s)
204
+ return f'{{\n{s}\n}}\n'
205
+
206
+
207
+ def init_dist_and_get_args():
208
+ for i in range(len(sys.argv)):
209
+ if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
210
+ del sys.argv[i]
211
+ break
212
+ args = Args(explicit_bool=True).parse_args(known_only=True)
213
+ if args.local_debug:
214
+ args.pn = '1_2_3'
215
+ args.seed = 1
216
+ args.aln = 1e-2
217
+ args.alng = 1e-5
218
+ args.saln = False
219
+ args.afuse = False
220
+ args.pg = 0.8
221
+ args.pg0 = 1
222
+ else:
223
+ if args.data_path == '/path/to/imagenet':
224
+ raise ValueError(f'{"*"*40} please specify --data_path=/path/to/imagenet {"*"*40}')
225
+
226
+ # warn args.extra_args
227
+ if len(args.extra_args) > 0:
228
+ print(f'======================================================================================')
229
+ print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
230
+ print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
231
+ print(f'======================================================================================\n\n')
232
+
233
+ # init torch distributed
234
+ from utils import misc
235
+ os.makedirs(args.local_out_dir_path, exist_ok=True)
236
+ misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)
237
+
238
+ # set env
239
+ args.set_tf32(args.tf32)
240
+ args.seed_everything(benchmark=args.pg == 0)
241
+
242
+ # update args: data loading
243
+ args.device = dist.get_device()
244
+ if args.pn == '256':
245
+ args.pn = '1_2_3_4_5_6_8_10_13_16'
246
+ elif args.pn == '512':
247
+ args.pn = '1_2_3_4_6_9_13_18_24_32'
248
+ elif args.pn == '1024':
249
+ args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'
250
+ args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
251
+ args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
252
+ args.data_load_reso = max(args.resos)
253
+
254
+ # update args: bs and lr
255
+ bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())
256
+ args.batch_size = bs_per_gpu
257
+ args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
258
+ args.workers = min(max(0, args.workers), args.batch_size)
259
+
260
+ args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
261
+ args.twde = args.twde or args.twd
262
+
263
+ if args.wp == 0:
264
+ args.wp = args.ep * 1/50
265
+
266
+ # update args: progressive training
267
+ if args.pgwp == 0:
268
+ args.pgwp = args.ep * 1/300
269
+ if args.pg > 0:
270
+ args.sche = f'lin{args.pg:g}'
271
+
272
+ # update args: paths
273
+ args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')
274
+ args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')
275
+ _reg_valid_name = re.compile(r'[^\w\-+,.]')
276
+ tb_name = _reg_valid_name.sub(
277
+ '_',
278
+ f'tb-VARd{args.depth}'
279
+ f'__pn{args.pn}'
280
+ f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}'
281
+ )
282
+ args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)
283
+
284
+ return args
VAR/utils/data.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+
3
+ import PIL.Image as PImage
4
+ from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
5
+ from torchvision.transforms import InterpolationMode, transforms
6
+
7
+
8
+ def normalize_01_into_pm1(x): # normalize x from [0, 1] to [-1, 1] by (x*2) - 1
9
+ return x.add(x).add_(-1)
10
+
11
+
12
+ def build_dataset(
13
+ data_path: str, final_reso: int,
14
+ hflip=False, mid_reso=1.125,
15
+ ):
16
+ # build augmentations
17
+ mid_reso = round(mid_reso * final_reso) # first resize to mid_reso, then crop to final_reso
18
+ train_aug, val_aug = [
19
+ transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso
20
+ transforms.RandomCrop((final_reso, final_reso)),
21
+ transforms.ToTensor(), normalize_01_into_pm1,
22
+ ], [
23
+ transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso
24
+ transforms.CenterCrop((final_reso, final_reso)),
25
+ transforms.ToTensor(), normalize_01_into_pm1,
26
+ ]
27
+ if hflip: train_aug.insert(0, transforms.RandomHorizontalFlip())
28
+ train_aug, val_aug = transforms.Compose(train_aug), transforms.Compose(val_aug)
29
+
30
+ # build dataset
31
+ train_set = DatasetFolder(root=osp.join(data_path, 'train'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=train_aug)
32
+ val_set = DatasetFolder(root=osp.join(data_path, 'val'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=val_aug)
33
+ num_classes = 1000
34
+ print(f'[Dataset] {len(train_set)=}, {len(val_set)=}, {num_classes=}')
35
+ print_aug(train_aug, '[train]')
36
+ print_aug(val_aug, '[val]')
37
+
38
+ return num_classes, train_set, val_set
39
+
40
+
41
+ def pil_loader(path):
42
+ with open(path, 'rb') as f:
43
+ img: PImage.Image = PImage.open(f).convert('RGB')
44
+ return img
45
+
46
+
47
+ def print_aug(transform, label):
48
+ print(f'Transform {label} = ')
49
+ if hasattr(transform, 'transforms'):
50
+ for t in transform.transforms:
51
+ print(t)
52
+ else:
53
+ print(transform)
54
+ print('---------------------------\n')
VAR/utils/data_sampler.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data.sampler import Sampler
4
+
5
+
6
+ class EvalDistributedSampler(Sampler):
7
+ def __init__(self, dataset, num_replicas, rank):
8
+ seps = np.linspace(0, len(dataset), num_replicas+1, dtype=int)
9
+ beg, end = seps[:-1], seps[1:]
10
+ beg, end = beg[rank], end[rank]
11
+ self.indices = tuple(range(beg, end))
12
+
13
+ def __iter__(self):
14
+ return iter(self.indices)
15
+
16
+ def __len__(self) -> int:
17
+ return len(self.indices)
18
+
19
+
20
+ class InfiniteBatchSampler(Sampler):
21
+ def __init__(self, dataset_len, batch_size, seed_for_all_rank=0, fill_last=False, shuffle=True, drop_last=False, start_ep=0, start_it=0):
22
+ self.dataset_len = dataset_len
23
+ self.batch_size = batch_size
24
+ self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size
25
+ self.max_p = self.iters_per_ep * batch_size
26
+ self.fill_last = fill_last
27
+ self.shuffle = shuffle
28
+ self.epoch = start_ep
29
+ self.same_seed_for_all_ranks = seed_for_all_rank
30
+ self.indices = self.gener_indices()
31
+ self.start_ep, self.start_it = start_ep, start_it
32
+
33
+ def gener_indices(self):
34
+ if self.shuffle:
35
+ g = torch.Generator()
36
+ g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
37
+ indices = torch.randperm(self.dataset_len, generator=g).numpy()
38
+ else:
39
+ indices = torch.arange(self.dataset_len).numpy()
40
+
41
+ tails = self.batch_size - (self.dataset_len % self.batch_size)
42
+ if tails != self.batch_size and self.fill_last:
43
+ tails = indices[:tails]
44
+ np.random.shuffle(indices)
45
+ indices = np.concatenate((indices, tails))
46
+
47
+ # built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
48
+ # noinspection PyTypeChecker
49
+ return tuple(indices.tolist())
50
+
51
+ def __iter__(self):
52
+ self.epoch = self.start_ep
53
+ while True:
54
+ self.epoch += 1
55
+ p = (self.start_it * self.batch_size) if self.epoch == self.start_ep else 0
56
+ while p < self.max_p:
57
+ q = p + self.batch_size
58
+ yield self.indices[p:q]
59
+ p = q
60
+ if self.shuffle:
61
+ self.indices = self.gener_indices()
62
+
63
+ def __len__(self):
64
+ return self.iters_per_ep
65
+
66
+
67
+ class DistInfiniteBatchSampler(InfiniteBatchSampler):
68
+ def __init__(self, world_size, rank, dataset_len, glb_batch_size, same_seed_for_all_ranks=0, repeated_aug=0, fill_last=False, shuffle=True, start_ep=0, start_it=0):
69
+ assert glb_batch_size % world_size == 0
70
+ self.world_size, self.rank = world_size, rank
71
+ self.dataset_len = dataset_len
72
+ self.glb_batch_size = glb_batch_size
73
+ self.batch_size = glb_batch_size // world_size
74
+
75
+ self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
76
+ self.fill_last = fill_last
77
+ self.shuffle = shuffle
78
+ self.repeated_aug = repeated_aug
79
+ self.epoch = start_ep
80
+ self.same_seed_for_all_ranks = same_seed_for_all_ranks
81
+ self.indices = self.gener_indices()
82
+ self.start_ep, self.start_it = start_ep, start_it
83
+
84
+ def gener_indices(self):
85
+ global_max_p = self.iters_per_ep * self.glb_batch_size # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0
86
+ # print(f'global_max_p = iters_per_ep({self.iters_per_ep}) * glb_batch_size({self.glb_batch_size}) = {global_max_p}')
87
+ if self.shuffle:
88
+ g = torch.Generator()
89
+ g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
90
+ global_indices = torch.randperm(self.dataset_len, generator=g)
91
+ if self.repeated_aug > 1:
92
+ global_indices = global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]
93
+ else:
94
+ global_indices = torch.arange(self.dataset_len)
95
+ filling = global_max_p - global_indices.shape[0]
96
+ if filling > 0 and self.fill_last:
97
+ global_indices = torch.cat((global_indices, global_indices[:filling]))
98
+ # global_indices = tuple(global_indices.numpy().tolist())
99
+
100
+ seps = torch.linspace(0, global_indices.shape[0], self.world_size + 1, dtype=torch.int)
101
+ local_indices = global_indices[seps[self.rank].item():seps[self.rank + 1].item()].tolist()
102
+ self.max_p = len(local_indices)
103
+ return local_indices
VAR/utils/lr_control.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from pprint import pformat
3
+ from typing import Tuple, List, Dict, Union
4
+
5
+ import torch.nn
6
+
7
+ import dist
8
+
9
+
10
+ def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
11
+ """Decay the learning rate with half-cycle cosine after warmup"""
12
+ wp_it = round(wp_it)
13
+
14
+ if cur_it < wp_it:
15
+ cur_lr = wp0 + (1-wp0) * cur_it / wp_it
16
+ else:
17
+ pasd = (cur_it - wp_it) / (max_it-1 - wp_it) # [0, 1]
18
+ rest = 1 - pasd # [1, 0]
19
+ if sche_type == 'cos':
20
+ cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))
21
+ elif sche_type == 'lin':
22
+ T = 0.15; max_rest = 1-T
23
+ if pasd < T: cur_lr = 1
24
+ else: cur_lr = wpe + (1-wpe) * rest / max_rest # 1 to wpe
25
+ elif sche_type == 'lin0':
26
+ T = 0.05; max_rest = 1-T
27
+ if pasd < T: cur_lr = 1
28
+ else: cur_lr = wpe + (1-wpe) * rest / max_rest
29
+ elif sche_type == 'lin00':
30
+ cur_lr = wpe + (1-wpe) * rest
31
+ elif sche_type.startswith('lin'):
32
+ T = float(sche_type[3:]); max_rest = 1-T
33
+ wpe_mid = wpe + (1-wpe) * max_rest
34
+ wpe_mid = (1 + wpe_mid) / 2
35
+ if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T
36
+ else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest
37
+ elif sche_type == 'exp':
38
+ T = 0.15; max_rest = 1-T
39
+ if pasd < T: cur_lr = 1
40
+ else:
41
+ expo = (pasd-T) / max_rest * math.log(wpe)
42
+ cur_lr = math.exp(expo)
43
+ else:
44
+ raise NotImplementedError(f'unknown sche_type {sche_type}')
45
+
46
+ cur_lr *= peak_lr
47
+ pasd = cur_it / (max_it-1)
48
+ cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd))
49
+
50
+ inf = 1e6
51
+ min_lr, max_lr = inf, -1
52
+ min_wd, max_wd = inf, -1
53
+ for param_group in optimizer.param_groups:
54
+ param_group['lr'] = cur_lr * param_group.get('lr_sc', 1) # 'lr_sc' could be assigned
55
+ max_lr = max(max_lr, param_group['lr'])
56
+ min_lr = min(min_lr, param_group['lr'])
57
+
58
+ param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1)
59
+ max_wd = max(max_wd, param_group['weight_decay'])
60
+ if param_group['weight_decay'] > 0:
61
+ min_wd = min(min_wd, param_group['weight_decay'])
62
+
63
+ if min_lr == inf: min_lr = -1
64
+ if min_wd == inf: min_wd = -1
65
+ return min_lr, max_lr, min_wd, max_wd
66
+
67
+
68
+ def filter_params(model, nowd_keys=()) -> Tuple[
69
+ List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
70
+ ]:
71
+ para_groups, para_groups_dbg = {}, {}
72
+ names, paras = [], []
73
+ names_no_grad = []
74
+ count, numel = 0, 0
75
+ for name, para in model.named_parameters():
76
+ name = name.replace('_fsdp_wrapped_module.', '')
77
+ if not para.requires_grad:
78
+ names_no_grad.append(name)
79
+ continue # frozen weights
80
+ count += 1
81
+ numel += para.numel()
82
+ names.append(name)
83
+ paras.append(para)
84
+
85
+ if para.ndim == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
86
+ cur_wd_sc, group_name = 0., 'ND'
87
+ else:
88
+ cur_wd_sc, group_name = 1., 'D'
89
+ cur_lr_sc = 1.
90
+ if group_name not in para_groups:
91
+ para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
92
+ para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
93
+ para_groups[group_name]['params'].append(para)
94
+ para_groups_dbg[group_name]['params'].append(name)
95
+
96
+ for g in para_groups_dbg.values():
97
+ g['params'] = pformat(', '.join(g['params']), width=200)
98
+
99
+ print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')
100
+
101
+ for rk in range(dist.get_world_size()):
102
+ dist.barrier()
103
+ if dist.get_rank() == rk:
104
+ print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)
105
+ print('')
106
+
107
+ assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
108
+ return names, paras, list(para_groups.values())
VAR/utils/misc.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import functools
3
+ import glob
4
+ import os
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from collections import defaultdict, deque
9
+ from typing import Iterator, List, Tuple
10
+
11
+ import numpy as np
12
+ import pytz
13
+ import torch
14
+ import torch.distributed as tdist
15
+
16
+ import dist
17
+ from utils import arg_util
18
+
19
+ os_system = functools.partial(subprocess.call, shell=True)
20
+ def echo(info):
21
+ os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')
22
+ def os_system_get_stdout(cmd):
23
+ return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
24
+ def os_system_get_stdout_stderr(cmd):
25
+ cnt = 0
26
+ while True:
27
+ try:
28
+ sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)
29
+ except subprocess.TimeoutExpired:
30
+ cnt += 1
31
+ print(f'[fetch free_port file] timeout cnt={cnt}')
32
+ else:
33
+ return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
34
+
35
+
36
+ def time_str(fmt='[%m-%d %H:%M:%S]'):
37
+ return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)
38
+
39
+
40
+ def init_distributed_mode(local_out_path, only_sync_master=False, timeout=30):
41
+ try:
42
+ dist.initialize(fork=False, timeout=timeout)
43
+ dist.barrier()
44
+ except RuntimeError:
45
+ print(f'{">"*75} NCCL Error {"<"*75}', flush=True)
46
+ time.sleep(10)
47
+
48
+ if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
49
+ _change_builtin_print(dist.is_local_master())
50
+ if (dist.is_master() if only_sync_master else dist.is_local_master()) and local_out_path is not None and len(local_out_path):
51
+ sys.stdout, sys.stderr = SyncPrint(local_out_path, sync_stdout=True), SyncPrint(local_out_path, sync_stdout=False)
52
+
53
+
54
+ def _change_builtin_print(is_master):
55
+ import builtins as __builtin__
56
+
57
+ builtin_print = __builtin__.print
58
+ if type(builtin_print) != type(open):
59
+ return
60
+
61
+ def prt(*args, **kwargs):
62
+ force = kwargs.pop('force', False)
63
+ clean = kwargs.pop('clean', False)
64
+ deeper = kwargs.pop('deeper', False)
65
+ if is_master or force:
66
+ if not clean:
67
+ f_back = sys._getframe().f_back
68
+ if deeper and f_back.f_back is not None:
69
+ f_back = f_back.f_back
70
+ file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
71
+ builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
72
+ else:
73
+ builtin_print(*args, **kwargs)
74
+
75
+ __builtin__.print = prt
76
+
77
+
78
+ class SyncPrint(object):
79
+ def __init__(self, local_output_dir, sync_stdout=True):
80
+ self.sync_stdout = sync_stdout
81
+ self.terminal_stream = sys.stdout if sync_stdout else sys.stderr
82
+ fname = os.path.join(local_output_dir, 'stdout.txt' if sync_stdout else 'stderr.txt')
83
+ existing = os.path.exists(fname)
84
+ self.file_stream = open(fname, 'a')
85
+ if existing:
86
+ self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str()} ' + '='*55 + '\n')
87
+ self.file_stream.flush()
88
+ self.enabled = True
89
+
90
+ def write(self, message):
91
+ self.terminal_stream.write(message)
92
+ self.file_stream.write(message)
93
+
94
+ def flush(self):
95
+ self.terminal_stream.flush()
96
+ self.file_stream.flush()
97
+
98
+ def close(self):
99
+ if not self.enabled:
100
+ return
101
+ self.enabled = False
102
+ self.file_stream.flush()
103
+ self.file_stream.close()
104
+ if self.sync_stdout:
105
+ sys.stdout = self.terminal_stream
106
+ sys.stdout.flush()
107
+ else:
108
+ sys.stderr = self.terminal_stream
109
+ sys.stderr.flush()
110
+
111
+ def __del__(self):
112
+ self.close()
113
+
114
+
115
+ class DistLogger(object):
116
+ def __init__(self, lg, verbose):
117
+ self._lg, self._verbose = lg, verbose
118
+
119
+ @staticmethod
120
+ def do_nothing(*args, **kwargs):
121
+ pass
122
+
123
+ def __getattr__(self, attr: str):
124
+ return getattr(self._lg, attr) if self._verbose else DistLogger.do_nothing
125
+
126
+
127
+ class TensorboardLogger(object):
128
+ def __init__(self, log_dir, filename_suffix):
129
+ try: import tensorflow_io as tfio
130
+ except: pass
131
+ from torch.utils.tensorboard import SummaryWriter
132
+ self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)
133
+ self.step = 0
134
+
135
+ def set_step(self, step=None):
136
+ if step is not None:
137
+ self.step = step
138
+ else:
139
+ self.step += 1
140
+
141
+ def update(self, head='scalar', step=None, **kwargs):
142
+ for k, v in kwargs.items():
143
+ if v is None:
144
+ continue
145
+ # assert isinstance(v, (float, int)), type(v)
146
+ if step is None: # iter wise
147
+ it = self.step
148
+ if it == 0 or (it + 1) % 500 == 0:
149
+ if hasattr(v, 'item'): v = v.item()
150
+ self.writer.add_scalar(f'{head}/{k}', v, it)
151
+ else: # epoch wise
152
+ if hasattr(v, 'item'): v = v.item()
153
+ self.writer.add_scalar(f'{head}/{k}', v, step)
154
+
155
+ def log_tensor_as_distri(self, tag, tensor1d, step=None):
156
+ if step is None: # iter wise
157
+ step = self.step
158
+ loggable = step == 0 or (step + 1) % 500 == 0
159
+ else: # epoch wise
160
+ loggable = True
161
+ if loggable:
162
+ try:
163
+ self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)
164
+ except Exception as e:
165
+ print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')
166
+
167
+ def log_image(self, tag, img_chw, step=None):
168
+ if step is None: # iter wise
169
+ step = self.step
170
+ loggable = step == 0 or (step + 1) % 500 == 0
171
+ else: # epoch wise
172
+ loggable = True
173
+ if loggable:
174
+ self.writer.add_image(tag, img_chw, step, dataformats='CHW')
175
+
176
+ def flush(self):
177
+ self.writer.flush()
178
+
179
+ def close(self):
180
+ self.writer.close()
181
+
182
+
183
+ class SmoothedValue(object):
184
+ """Track a series of values and provide access to smoothed values over a
185
+ window or the global series average.
186
+ """
187
+
188
+ def __init__(self, window_size=30, fmt=None):
189
+ if fmt is None:
190
+ fmt = "{median:.4f} ({global_avg:.4f})"
191
+ self.deque = deque(maxlen=window_size)
192
+ self.total = 0.0
193
+ self.count = 0
194
+ self.fmt = fmt
195
+
196
+ def update(self, value, n=1):
197
+ self.deque.append(value)
198
+ self.count += n
199
+ self.total += value * n
200
+
201
+ def synchronize_between_processes(self):
202
+ """
203
+ Warning: does not synchronize the deque!
204
+ """
205
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
206
+ tdist.barrier()
207
+ tdist.all_reduce(t)
208
+ t = t.tolist()
209
+ self.count = int(t[0])
210
+ self.total = t[1]
211
+
212
+ @property
213
+ def median(self):
214
+ return np.median(self.deque) if len(self.deque) else 0
215
+
216
+ @property
217
+ def avg(self):
218
+ return sum(self.deque) / (len(self.deque) or 1)
219
+
220
+ @property
221
+ def global_avg(self):
222
+ return self.total / (self.count or 1)
223
+
224
+ @property
225
+ def max(self):
226
+ return max(self.deque)
227
+
228
+ @property
229
+ def value(self):
230
+ return self.deque[-1] if len(self.deque) else 0
231
+
232
+ def time_preds(self, counts) -> Tuple[float, str, str]:
233
+ remain_secs = counts * self.median
234
+ return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs))
235
+
236
+ def __str__(self):
237
+ return self.fmt.format(
238
+ median=self.median,
239
+ avg=self.avg,
240
+ global_avg=self.global_avg,
241
+ max=self.max,
242
+ value=self.value)
243
+
244
+
245
+ class MetricLogger(object):
246
+ def __init__(self, delimiter=' '):
247
+ self.meters = defaultdict(SmoothedValue)
248
+ self.delimiter = delimiter
249
+ self.iter_end_t = time.time()
250
+ self.log_iters = []
251
+
252
+ def update(self, **kwargs):
253
+ for k, v in kwargs.items():
254
+ if v is None:
255
+ continue
256
+ if hasattr(v, 'item'): v = v.item()
257
+ # assert isinstance(v, (float, int)), type(v)
258
+ assert isinstance(v, (float, int))
259
+ self.meters[k].update(v)
260
+
261
+ def __getattr__(self, attr):
262
+ if attr in self.meters:
263
+ return self.meters[attr]
264
+ if attr in self.__dict__:
265
+ return self.__dict__[attr]
266
+ raise AttributeError("'{}' object has no attribute '{}'".format(
267
+ type(self).__name__, attr))
268
+
269
+ def __str__(self):
270
+ loss_str = []
271
+ for name, meter in self.meters.items():
272
+ if len(meter.deque):
273
+ loss_str.append(
274
+ "{}: {}".format(name, str(meter))
275
+ )
276
+ return self.delimiter.join(loss_str)
277
+
278
+ def synchronize_between_processes(self):
279
+ for meter in self.meters.values():
280
+ meter.synchronize_between_processes()
281
+
282
+ def add_meter(self, name, meter):
283
+ self.meters[name] = meter
284
+
285
+ def log_every(self, start_it, max_iters, itrt, print_freq, header=None):
286
+ self.log_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist())
287
+ self.log_iters.add(start_it)
288
+ if not header:
289
+ header = ''
290
+ start_time = time.time()
291
+ self.iter_end_t = time.time()
292
+ self.iter_time = SmoothedValue(fmt='{avg:.4f}')
293
+ self.data_time = SmoothedValue(fmt='{avg:.4f}')
294
+ space_fmt = ':' + str(len(str(max_iters))) + 'd'
295
+ log_msg = [
296
+ header,
297
+ '[{0' + space_fmt + '}/{1}]',
298
+ 'eta: {eta}',
299
+ '{meters}',
300
+ 'time: {time}',
301
+ 'data: {data}'
302
+ ]
303
+ log_msg = self.delimiter.join(log_msg)
304
+
305
+ if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
306
+ for i in range(start_it, max_iters):
307
+ obj = next(itrt)
308
+ self.data_time.update(time.time() - self.iter_end_t)
309
+ yield i, obj
310
+ self.iter_time.update(time.time() - self.iter_end_t)
311
+ if i in self.log_iters:
312
+ eta_seconds = self.iter_time.global_avg * (max_iters - i)
313
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
314
+ print(log_msg.format(
315
+ i, max_iters, eta=eta_string,
316
+ meters=str(self),
317
+ time=str(self.iter_time), data=str(self.data_time)), flush=True)
318
+ self.iter_end_t = time.time()
319
+ else:
320
+ if isinstance(itrt, int): itrt = range(itrt)
321
+ for i, obj in enumerate(itrt):
322
+ self.data_time.update(time.time() - self.iter_end_t)
323
+ yield i, obj
324
+ self.iter_time.update(time.time() - self.iter_end_t)
325
+ if i in self.log_iters:
326
+ eta_seconds = self.iter_time.global_avg * (max_iters - i)
327
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
328
+ print(log_msg.format(
329
+ i, max_iters, eta=eta_string,
330
+ meters=str(self),
331
+ time=str(self.iter_time), data=str(self.data_time)), flush=True)
332
+ self.iter_end_t = time.time()
333
+
334
+ total_time = time.time() - start_time
335
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
336
+ print('{} Total time: {} ({:.3f} s / it)'.format(
337
+ header, total_time_str, total_time / max_iters), flush=True)
338
+
339
+
340
+ def glob_with_latest_modified_first(pattern, recursive=False):
341
+ return sorted(glob.glob(pattern, recursive=recursive), key=os.path.getmtime, reverse=True)
342
+
343
+
344
+ def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, dict, dict]:
345
+ info = []
346
+ file = os.path.join(args.local_out_dir_path, pattern)
347
+ all_ckpt = glob_with_latest_modified_first(file)
348
+ if len(all_ckpt) == 0:
349
+ info.append(f'[auto_resume] no ckpt found @ {file}')
350
+ info.append(f'[auto_resume quit]')
351
+ return info, 0, 0, {}, {}
352
+ else:
353
+ info.append(f'[auto_resume] load ckpt from @ {all_ckpt[0]} ...')
354
+ ckpt = torch.load(all_ckpt[0], map_location='cpu')
355
+ ep, it = ckpt['epoch'], ckpt['iter']
356
+ info.append(f'[auto_resume success] resume from ep{ep}, it{it}')
357
+ return info, ep, it, ckpt['trainer'], ckpt['args']
358
+
359
+
360
+ def create_npz_from_sample_folder(sample_folder: str):
361
+ """
362
+ Builds a single .npz file from a folder of .png samples. Refer to DiT.
363
+ """
364
+ import os, glob
365
+ import numpy as np
366
+ from tqdm import tqdm
367
+ from PIL import Image
368
+
369
+ samples = []
370
+ pngs = glob.glob(os.path.join(sample_folder, '*.png')) + glob.glob(os.path.join(sample_folder, '*.PNG'))
371
+ assert len(pngs) == 50_000, f'{len(pngs)} png files found in {sample_folder}, but expected 50,000'
372
+ for png in tqdm(pngs, desc='Building .npz file from samples (png only)'):
373
+ with Image.open(png) as sample_pil:
374
+ sample_np = np.asarray(sample_pil).astype(np.uint8)
375
+ samples.append(sample_np)
376
+ samples = np.stack(samples)
377
+ assert samples.shape == (50_000, samples.shape[1], samples.shape[2], 3)
378
+ npz_path = f'{sample_folder}.npz'
379
+ np.savez(npz_path, arr_0=samples)
380
+ print(f'Saved .npz file to {npz_path} [shape={samples.shape}].')
381
+ return npz_path
infrance_example.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from models import VQVAE, build_vae_var
3
+ from dataset.imagenet_dataset import get_train_transforms
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+
7
+
8
+ device = 'mps'
9
+ patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
10
+
11
+ vae, var = build_vae_var(
12
+ V=4096, Cvae=32, ch=160, share_quant_resi=4,
13
+ device=device, patch_nums=patch_nums,
14
+ num_classes=1000, depth=16, shared_aln=False,
15
+ )
16
+ var_ckpt='var_d16.pth'
17
+ vae_ckpt='vae_ch160v4096z32.pth'
18
+ var.load_state_dict(torch.load(var_ckpt, map_location=device), strict=True)
19
+ vae.load_state_dict(torch.load(vae_ckpt, map_location=device), strict=True)
model-step-step=32000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4222b988d8108a1288367e282654e1e7819dd4d1d34b4b1efecec924d94d718d
3
+ size 4161643659
vae_ch160v4096z32.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c3ec27ae28a3f87055e83211ea8cc8558bd1985d7b51742d074fb4c2fcf186c
3
+ size 436075834