# **Scaling Image Tokenizers with Grouped Spherical Quantization**
---
[Paper link](https://arxiv.org/abs/2412.02632) | [GitHub repo](https://github.com/HelmholtzAI-FZJ/flex_gen) | [HF Checkpoints](https://huggingface.co/collections/HelmholtzAI-FZJ/grouped-spherical-quantization-674d6f9f548e472d0eaf179e)
In [GSQ](https://arxiv.org/abs/2412.02632), we present optimized training hyper-parameters and configurations for quantization-based image tokenizers. We also show how to scale the latent dimension, vocabulary size, and related settings to achieve better reconstruction performance.
![dim-vocab-scaling.png](https://github.com/HelmholtzAI-FZJ/flex_gen/raw/main/figures/dim-vocab-scaling.png)
We also show how to scale the latent dimension (and the number of groups) appropriately when pursuing a high down-sampling ratio for compression.
![spatial_scale.png](https://github.com/HelmholtzAI-FZJ/flex_gen/raw/main/figures/spatial_scale.png)
The group scaling experiment of GSQ:
---
| **Models** | $G \times d$ | **rFID ↓** | **IS ↑** | **LPIPS ↓** | **PSNR ↑** | **SSIM ↑** | **Usage ↑** | **PPL ↑** |
|------------------------------|------------------|------------|----------|-------------|------------|------------|-------------|-------------|
| **GSQ F8-D64** ($V=8K$) | $1 \times 64$ | 0.63 | 205 | 0.08 | 22.95 | 0.67 | 99.87% | 8,055 |
| | $2 \times 32$ | 0.32 | 220 | 0.05 | 25.42 | 0.76 | 100% | 8,157 |
| | $4 \times 16$ | 0.18 | 226 | 0.03 | 28.02 | 0.08 | 100% | 8,143 |
| | $16 \times 4$ | **0.03** | **233** | **0.004** | **34.61** | **0.91** | **99.98%** | **6,775** |
| **GSQ F16-D16** ($V=256K$) | $1 \times 16$ | 1.42 | 179 | 0.13 | 20.70 | 0.56 | 100% | 254,044 |
| | $2 \times 8$ | 0.82 | 199 | 0.09 | 22.20 | 0.63 | 100% | 257,273 |
| | $4 \times 4$ | 0.74 | 202 | 0.08 | 22.75 | 0.63 | 62.46% | 43,767 |
| | $8 \times 2$ | 0.50 | 211 | 0.06 | 23.62 | 0.66 | 46.83% | 22,181 |
| | $16 \times 1$ | 0.52 | 210 | 0.06 | 23.54 | 0.66 | 50.81% | 181 |
| | $16 \times 1^*$ | 0.51 | 210 | 0.06 | 23.52 | 0.66 | 52.64% | 748 |
| **GSQ F32-D32** ($V=256K$) | $1 \times 32$ | 6.84 | 95 | 0.24 | 17.83 | 0.40 | 100% | 245,715 |
| | $2 \times 16$ | 3.31 | 139 | 0.18 | 19.01 | 0.47 | 100% | 253,369 |
| | $4 \times 8$ | 1.77 | 173 | 0.13 | 20.60 | 0.53 | 100% | 253,199 |
| | $8 \times 4$ | 1.67 | 176 | 0.12 | 20.88 | 0.54 | 59% | 40,307 |
| | $16 \times 2$ | 1.13 | 190 | 0.10 | 21.73 | 0.57 | 46% | 30,302 |
| | $32 \times 1$ | 1.21 | 187 | 0.10 | 21.64 | 0.57 | 54% | 247 |
---
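The G × d column can be read as follows: a latent vector of D channels is split into G groups of d = D / G channels, and each group is quantized separately on the unit sphere. The sketch below is a minimal NumPy illustration of that idea, together with the usual definitions of codebook usage and perplexity. It is not the `flex_gen` implementation: the function names, the codebook shared across groups, the toy sizes, and the random inputs are all assumptions made for illustration.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
    """Project vectors onto the unit sphere along `axis`."""
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)


def grouped_spherical_quantize(z, codebook, num_groups):
    """Quantize latents z of shape (N, D) using G groups of d = D // G channels.

    `codebook` has shape (V, d) and is shared by all groups
    (an assumption of this sketch).
    Returns indices of shape (N, G) and quantized latents of shape (N, D).
    """
    n, dim = z.shape
    if dim % num_groups != 0:
        raise ValueError("latent dimension must be divisible by num_groups")
    d = dim // num_groups
    if codebook.shape[1] != d:
        raise ValueError("codebook width must equal D // G")
    groups = l2_normalize(z.reshape(n * num_groups, d))
    codes = l2_normalize(codebook)
    # On the unit sphere, the nearest code is the one with the
    # largest cosine similarity.
    indices = np.argmax(groups @ codes.T, axis=1)
    z_q = codes[indices].reshape(n, dim)
    return indices.reshape(n, num_groups), z_q


def usage_and_perplexity(indices, vocab_size):
    """Fraction of codes used at least once, and exp(entropy) of code frequencies."""
    counts = np.bincount(indices.ravel(), minlength=vocab_size)
    probs = counts / counts.sum()
    usage = np.count_nonzero(counts) / vocab_size
    nonzero = probs[probs > 0]
    perplexity = float(np.exp(-(nonzero * np.log(nonzero)).sum()))
    return usage, perplexity


rng = np.random.default_rng(0)
z = rng.normal(size=(1024, 64))        # toy latents: 1024 positions, D = 64
codebook = rng.normal(size=(512, 16))  # toy codebook: V = 512, d = 16
idx, z_q = grouped_spherical_quantize(z, codebook, num_groups=4)  # 4 x 16
usage, ppl = usage_and_perplexity(idx, vocab_size=512)
print(idx.shape, z_q.shape, usage, ppl)
```

Under these definitions, perplexity is bounded above by the vocabulary size and reaches it only when all codes are used equally often, which is why PPL values close to V in the table go together with usage near 100%.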
## Use Pre-trained GSQ-Tokenizer
```python
from timm import create_model
# Imported for its side effect: presumably registers 'flexTokenizer' with timm.
from flex_gen import autoencoders  # noqa: F401

# ============= From the Hugging Face repo
model = create_model('flexTokenizer', pretrained=True,
                     repo_id='HelmholtzAI-FZJ/GSQ-F8-D8-V64k')

# ============= From a local checkpoint
model = create_model('flexTokenizer', pretrained=True,
                     path='PATH/your_checkpoint.pt')
```
---
## Training your tokenizer
### Set-up Python Virtual Environment
```bash
sh gen_env/setup.sh
source ./gen_env/activate.sh
# This runs pip install to download all required libraries
sh ./gen_env/install_requirements.sh
```
### Run Training
```bash
# Single GPU
python -W ignore ./scripts/train_autoencoder.py
# Multi GPU
torchrun --nnodes=1 --nproc_per_node=4 ./scripts/train_autoencoder.py --config-file=PATH/config_name.yaml \
    --output_dir=./logs_test/test opts train.num_train_steps=100 train_batch_size=16
```
### Run Evaluation
Add the checkpoint path that you want to test to `evaluation/run_tokenizer_eval.sh`:
```bash
# For example
...
configs_of_training_lists=()
configs_of_training_lists=("logs_test/test/")
...
```
Then run `sh evaluation/run_tokenizer_eval.sh`; it automatically scans `folder/model/eval_xxx.pth` for tokenizer evaluation.
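
Before launching the evaluation, it can help to confirm that a run folder contains checkpoints in the expected layout. The helper below is a hypothetical convenience and not part of the repository. It assumes that `eval_xxx.pth` stands for any file matching `eval_*.pth` inside a `model/` subfolder of the run directory.

```python
from pathlib import Path


def find_eval_checkpoints(run_dir):
    """Return sorted paths matching <run_dir>/model/eval_*.pth.

    The glob pattern is an assumption based on the `folder/model/eval_xxx.pth`
    layout described above.
    """
    return sorted(Path(run_dir).glob("model/eval_*.pth"))


# Example: list what the evaluation script would be expected to pick up.
for ckpt in find_eval_checkpoints("logs_test/test/"):
    print(ckpt)
```

If nothing is printed, the folder either does not exist or holds no checkpoints matching the assumed pattern.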
---
# **Citation**
```bibtex
@misc{GSQ,
title={Scaling Image Tokenizers with Grouped Spherical Quantization},
author={Jiangtao Wang and Zhen Qin and Yifan Zhang and Vincent Tao Hu and Björn Ommer and Rania Briq and Stefan Kesselheim},
year={2024},
eprint={2412.02632},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2412.02632},
}
```