# **Scaling Image Tokenizers with Grouped Spherical Quantization**

---

[Paper link](https://arxiv.org/abs/2412.02632) | [GitHub repo](https://github.com/HelmholtzAI-FZJ/flex_gen) | [HF checkpoints](https://huggingface.co/collections/HelmholtzAI-FZJ/grouped-spherical-quantization-674d6f9f548e472d0eaf179e)

In [GSQ](https://arxiv.org/abs/2412.02632), we present optimized training hyper-parameters and configurations for quantization-based image tokenizers, and show how to scale the latent dimension, vocabulary size, etc. appropriately to achieve better reconstruction performance.

![dim-vocab-scaling.png](https://github.com/HelmholtzAI-FZJ/flex_gen/raw/main/figures/dim-vocab-scaling.png)

We also show how to scale the latent (and group) dimensions appropriately when pursuing a high down-sampling ratio in compression.

![spatial_scale.png](https://github.com/HelmholtzAI-FZJ/flex_gen/raw/main/figures/spatial_scale.png)

The group-scaling experiments of GSQ:

---

| **Models** | $G \times d$ | **rFID ↓** | **IS ↑** | **LPIPS ↓** | **PSNR ↑** | **SSIM ↑** | **Usage ↑** | **PPL ↑** |
|------------|--------------|------------|----------|-------------|------------|------------|-------------|-----------|
| **GSQ F8-D64** ($V=8K$) | $1 \times 64$ | 0.63 | 205 | 0.08 | 22.95 | 0.67 | 99.87% | 8,055 |
| | $2 \times 32$ | 0.32 | 220 | 0.05 | 25.42 | 0.76 | 100% | 8,157 |
| | $4 \times 16$ | 0.18 | 226 | 0.03 | 28.02 | 0.08 | 100% | 8,143 |
| | $16 \times 4$ | **0.03** | **233** | **0.004** | **34.61** | **0.91** | **99.98%** | **6,775** |
| **GSQ F16-D16** ($V=256K$) | $1 \times 16$ | 1.42 | 179 | 0.13 | 20.70 | 0.56 | 100% | 254,044 |
| | $2 \times 8$ | 0.82 | 199 | 0.09 | 22.20 | 0.63 | 100% | 257,273 |
| | $4 \times 4$ | 0.74 | 202 | 0.08 | 22.75 | 0.63 | 62.46% | 43,767 |
| | $8 \times 2$ | 0.50 | 211 | 0.06 | 23.62 | 0.66 | 46.83% | 22,181 |
| | $16 \times 1$ | 0.52 | 210 | 0.06 | 23.54 | 0.66 | 50.81% | 181 |
| | $16 \times 1^*$ | 0.51 | 210 | 0.06 | 23.52 | 0.66 | 52.64% | 748 |
| **GSQ F32-D32** ($V=256K$) | $1 \times 32$ | 6.84 | 95 | 0.24 | 17.83 | 0.40 | 100% | 245,715 |
| | $2 \times 16$ | 3.31 | 139 | 0.18 | 19.01 | 0.47 | 100% | 253,369 |
| | $4 \times 8$ | 1.77 | 173 | 0.13 | 20.60 | 0.53 | 100% | 253,199 |
| | $8 \times 4$ | 1.67 | 176 | 0.12 | 20.88 | 0.54 | 59% | 40,307 |
| | $16 \times 2$ | 1.13 | 190 | 0.10 | 21.73 | 0.57 | 46% | 30,302 |
| | $32 \times 1$ | 1.21 | 187 | 0.10 | 21.64 | 0.57 | 54% | 247 |

---
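
To make the $G \times d$ notation concrete, below is a minimal, self-contained sketch of the grouping idea: a $D$-dimensional latent is split into $G$ groups of $d$ dimensions, each group is projected onto the unit sphere, and each group is quantized independently against a shared spherical codebook of size $V$. This is an illustrative approximation only; the function and tensor shapes are hypothetical, not the flex_gen implementation.

```python
import torch
import torch.nn.functional as F

def gsq_sketch(z, codebook, groups):
    """Illustrative grouped spherical quantization (not the flex_gen implementation).

    z:        (B, D) latents, with D = groups * d
    codebook: (V, d) codewords, assumed L2-normalized
    groups:   number of groups G
    """
    B, D = z.shape
    d = D // groups
    z = F.normalize(z.view(B, groups, d), dim=-1)   # project each group onto the unit sphere

    sim = torch.einsum('bgd,vd->bgv', z, codebook)  # cosine similarity to every codeword
    idx = sim.argmax(dim=-1)                        # (B, G) discrete token per group
    z_q = codebook[idx]                             # (B, G, d) quantized groups

    return z_q.reshape(B, D), idx

# Example matching the F8-D64 row with G=16, d=4, V=8K:
codebook = F.normalize(torch.randn(8192, 4), dim=-1)
z_q, tokens = gsq_sketch(torch.randn(2, 64), codebook, groups=16)
```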

## Use Pre-trained GSQ-Tokenizer

```python
from flex_gen import autoencoders
from timm import create_model

# ============= From HF's repo
model = create_model('flexTokenizer', pretrained=True,
                     repo_id='HelmholtzAI-FZJ/GSQ-F8-D8-V64k')

# ============= From Local Checkpoint
model = create_model('flexTokenizer', pretrained=True,
                     path='PATH/your_checkpoint.pt')
```
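
As a quick smoke test of a loaded checkpoint, something along these lines should work. The preprocessing (256×256 input, pixels mapped to [-1, 1]) and the assumption that the tokenizer's forward pass returns a reconstruction are illustrative guesses rather than the documented flex_gen API; check the repository for the exact encode/decode interface.

```python
import torch
from PIL import Image
from torchvision import transforms

from flex_gen import autoencoders
from timm import create_model

model = create_model('flexTokenizer', pretrained=True,
                     repo_id='HelmholtzAI-FZJ/GSQ-F8-D8-V64k').eval()

# Assumed preprocessing: 256x256 resize, pixels scaled to [-1, 1].
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
img = preprocess(Image.open('example.jpg').convert('RGB')).unsqueeze(0)

with torch.no_grad():
    # Assumption: the autoencoder's forward pass returns a reconstruction
    # (the real API may return a tuple/dict with tokens and losses as well).
    recon = model(img)
```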

---

## Training Your Tokenizer

### Set Up the Python Virtual Environment

```bash
sh gen_env/setup.sh

source ./gen_env/activate.sh

# This runs pip install to download all required libraries
sh ./gen_env/install_requirements.sh
```

### Run Training

```bash
# Single GPU
python -W ignore ./scripts/train_autoencoder.py

# Multi GPU
torchrun --nnodes=1 --nproc_per_node=4 ./scripts/train_autoencoder.py --config-file=PATH/config_name.yaml \
    --output_dir=./logs_test/test opts train.num_train_steps=100 train_batch_size=16
```
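
For multi-node runs, the same script can be launched with standard `torchrun` rendezvous flags; the node count, endpoint, and override values below are placeholders, not tested settings from this repository.

```bash
# Hypothetical 2-node x 4-GPU launch; adjust the rendezvous endpoint and
# config/override values to your cluster and experiment.
torchrun --nnodes=2 --nproc_per_node=4 \
    --rdzv_backend=c10d --rdzv_endpoint=MASTER_HOST:29500 \
    ./scripts/train_autoencoder.py --config-file=PATH/config_name.yaml \
    --output_dir=./logs/my_run opts train.num_train_steps=100000 train_batch_size=16
```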

### Run Evaluation

Add the checkpoint path that you want to test in `evaluation/run_tokenizer_eval.sh`:

```bash
# For example
...
configs_of_training_lists=()
configs_of_training_lists=("logs_test/test/")
...
```

Then run `sh evaluation/run_tokenizer_eval.sh`; it will automatically scan `folder/model/eval_xxx.pth` for tokenizer evaluation.

---

# **Citation**

```bibtex
@misc{GSQ,
      title={Scaling Image Tokenizers with Grouped Spherical Quantization},
      author={Jiangtao Wang and Zhen Qin and Yifan Zhang and Vincent Tao Hu and Björn Ommer and Rania Briq and Stefan Kesselheim},
      year={2024},
      eprint={2412.02632},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2412.02632},
}
```