# (NeurIPS 2024) Compact and Mighty - Image Tokenization with Only 32 Tokens for both Reconstruction and Generation!
<div align="center">
[![demo](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Online_Demo-blue)](https://huggingface.co/spaces/fun-research/TiTok)
[![Website](https://img.shields.io/badge/Project-Website-87CEEB)](https://yucornetto.github.io/projects/titok.html)
[![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2406.07550)
</div>
We present a compact 1D tokenizer that can represent an image with as few as 32 discrete tokens. As a result, it substantially speeds up the sampling process (e.g., **410× faster** than DiT-XL/2) while achieving competitive generation quality.
<p>
<img src="assets/titok_teaser.png" alt="teaser" width=90% height=90%>
</p>
<p>
<img src="assets/speed_vs_perf.png" alt="teaser" width=90% height=90%>
</p>
## Contributions
#### We introduce a novel 1D image tokenization framework that breaks the grid constraints of 2D tokenization methods, leading to a much more flexible and compact image latent representation.
#### The proposed 1D tokenizer can tokenize a 256×256 image into as few as 32 discrete tokens, leading to a significant speed-up (hundreds of times faster than diffusion models) in the generation process, while maintaining state-of-the-art generation quality.
#### We conduct a series of experiments to probe the properties of the rarely studied 1D image tokenization, paving the way towards a compact latent space for efficient and effective image representation.
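To put the compactness in perspective, a quick back-of-the-envelope comparison helps (an illustration only; the baseline assumes a typical 2D VQGAN-style tokenizer with downsampling factor 16):
```python
# Token-count comparison (illustrative): a 2D tokenizer with downsampling
# factor 16 maps a 256x256 image to a 16x16 grid of 256 tokens, whereas
# TiTok-L-32 represents the same image with only 32 tokens.
image_size, downsample_factor = 256, 16
grid_tokens = (image_size // downsample_factor) ** 2  # 256 tokens
titok_tokens = 32
print(f"2D grid: {grid_tokens} tokens vs. TiTok: {titok_tokens} tokens "
      f"({grid_tokens // titok_tokens}x fewer)")
```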
## Model Zoo
| Model | Link | FID |
| ------------- | ------------- | ------------- |
| TiTok-L-32 Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_l32_imagenet)| 2.21 (reconstruction) |
| TiTok-B-64 Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_b64_imagenet) | 1.70 (reconstruction) |
| TiTok-S-128 Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_s128_imagenet) | 1.71 (reconstruction) |
| TiTok-L-32 Generator | [checkpoint](https://huggingface.co/yucornetto/generator_titok_l32_imagenet) | 2.77 (generation) |
| TiTok-B-64 Generator | [checkpoint](https://huggingface.co/yucornetto/generator_titok_b64_imagenet) | 2.48 (generation) |
| TiTok-S-128 Generator | [checkpoint](https://huggingface.co/yucornetto/generator_titok_s128_imagenet) | 1.97 (generation) |
| TiTok-BL-64 VQ Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_bl64_vq8k_imagenet)| 2.06 (reconstruction) |
| TiTok-BL-128 VQ Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_bl128_vq8k_imagenet)| 1.49 (reconstruction) |
| TiTok-SL-256 VQ Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_sl256_vq8k_imagenet)| 1.03 (reconstruction) |
| TiTok-LL-32 VAE Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_ll32_vae_c16_imagenet)| 1.61 (reconstruction) |
| TiTok-BL-64 VAE Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_bl64_vae_c16_imagenet)| 1.25 (reconstruction) |
| TiTok-BL-128 VAE Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_bl128_vae_c16_imagenet)| 0.84 (reconstruction) |
Please note that these models were trained only on the limited academic dataset ImageNet and are intended for research purposes only.
## Installation
```shell
pip3 install -r requirements.txt
```
## Get Started
```python
import torch
from PIL import Image
import numpy as np
import demo_util
from huggingface_hub import hf_hub_download
from modeling.maskgit import ImageBert
from modeling.titok import TiTok
# Choose one from ["tokenizer_titok_l32_imagenet", "tokenizer_titok_b64_imagenet",
# "tokenizer_titok_s128_imagenet", "tokenizer_titok_bl128_vae_c16_imagenet", tokenizer_titok_bl64_vae_c16_imagenet",
# "tokenizer_titok_ll32_vae_c16_imagenet", "tokenizer_titok_sl256_vq8k_imagenet", "tokenizer_titok_bl128_vq8k_imagenet",
# "tokenizer_titok_bl64_vq8k_imagenet",]
titok_tokenizer = TiTok.from_pretrained("yucornetto/tokenizer_titok_l32_imagenet")
titok_tokenizer.eval()
titok_tokenizer.requires_grad_(False)
titok_generator = ImageBert.from_pretrained("yucornetto/generator_titok_l32_imagenet")
titok_generator.eval()
titok_generator.requires_grad_(False)
# Alternatively, download the checkpoints from the Hugging Face Hub:
# hf_hub_download(repo_id="fun-research/TiTok", filename="tokenizer_titok_l32.bin", local_dir="./")
# hf_hub_download(repo_id="fun-research/TiTok", filename="generator_titok_l32.bin", local_dir="./")
# load config
# config = demo_util.get_config("configs/infer/titok_l32.yaml")
# titok_tokenizer = demo_util.get_titok_tokenizer(config)
# titok_generator = demo_util.get_titok_generator(config)
device = "cuda"
titok_tokenizer = titok_tokenizer.to(device)
titok_generator = titok_generator.to(device)
# reconstruct an image. I.e., image -> 32 tokens -> image
img_path = "assets/ILSVRC2012_val_00010240.png"
image = torch.from_numpy(np.array(Image.open(img_path)).astype(np.float32)).permute(2, 0, 1).unsqueeze(0) / 255.0
# tokenization
if titok_tokenizer.quantize_mode == "vq":
encoded_tokens = titok_tokenizer.encode(image.to(device))[1]["min_encoding_indices"]
elif titok_tokenizer.quantize_mode == "vae":
posteriors = titok_tokenizer.encode(image.to(device))[1]
encoded_tokens = posteriors.sample()
else:
raise NotImplementedError
# Example output: image assets/ILSVRC2012_val_00010240.png is encoded into tokens tensor([[[ 887, 3979, 349, 720, 2809, 2743, 2101, 603, 2205, 1508, 1891, 4015, 1317, 2956, 3774, 2296, 484, 2612, 3472, 2330, 3140, 3113, 1056, 3779, 654, 2360, 1901, 2908, 2169, 953, 1326, 2598]]], device='cuda:0'), with shape torch.Size([1, 1, 32])
print(f"image {img_path} is encoded into tokens {encoded_tokens}, with shape {encoded_tokens.shape}")
# de-tokenization
reconstructed_image = titok_tokenizer.decode_tokens(encoded_tokens)
reconstructed_image = torch.clamp(reconstructed_image, 0.0, 1.0)
reconstructed_image = (reconstructed_image * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0]
Image.fromarray(reconstructed_image).save("assets/ILSVRC2012_val_00010240_recon.png")
# generate an image
sample_labels = [torch.randint(0, 1000, size=(1,)).item()]  # random ImageNet-1K class (upper bound is exclusive)
generated_image = demo_util.sample_fn(
generator=titok_generator,
tokenizer=titok_tokenizer,
labels=sample_labels,
guidance_scale=4.5,
randomize_temperature=1.0,
num_sample_steps=8,
device=device
)
Image.fromarray(generated_image[0]).save(f"assets/generated_{sample_labels[0]}.png")
```
We also provide a [Jupyter notebook](demo.ipynb) with a quick tutorial on reconstructing and generating images with TiTok-L-32.

We also support TiTok in an online [Hugging Face demo](https://huggingface.co/spaces/fun-research/TiTok)!
## Testing on ImageNet-1K Benchmark
We provide a [sampling script](./sample_imagenet_titok.py) for reproducing the generation results on ImageNet-1K benchmark.
```bash
# Prepare ADM evaluation script
git clone https://github.com/openai/guided-diffusion.git
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
```
```bash
# Reproducing TiTok-L-32
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_titok.py config=configs/infer/titok_l32.yaml experiment.output_dir="titok_l_32"
# Run eval script. The result FID should be ~2.77
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz titok_l_32.npz
# Reproducing TiTok-B-64
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_titok.py config=configs/infer/titok_b64.yaml experiment.output_dir="titok_b_64"
# Run eval script. The result FID should be ~2.48
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz titok_b_64.npz
# Reproducing TiTok-S-128
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_titok.py config=configs/infer/titok_s128.yaml experiment.output_dir="titok_s_128"
# Run eval script. The result FID should be ~1.97
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz titok_s_128.npz
```
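For reference, the ADM evaluator reads the generated samples as a uint8 array of shape `(N, 256, 256, 3)` stored in the `.npz` file under the default key `arr_0`. The sampling script above already produces such a file, but if you need to assemble one from your own samples, a minimal sketch looks like this (`all_samples` is a placeholder for images collected, e.g., from `demo_util.sample_fn`):
```python
import numpy as np

# Placeholder for generated images collected during sampling, each a
# (256, 256, 3) uint8 array, e.g. returned by demo_util.sample_fn.
all_samples = [np.zeros((256, 256, 3), dtype=np.uint8)]

samples = np.stack(all_samples, axis=0)
assert samples.dtype == np.uint8 and samples.shape[1:] == (256, 256, 3)
np.savez("titok_l_32.npz", samples)  # stored under the default key "arr_0"
```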
## Training Preparation
We use the [webdataset](https://github.com/webdataset/webdataset) format for data loading. To begin, you need to convert your dataset into the webdataset format. An example script that converts ImageNet to the wds format is provided [here](./data/convert_imagenet_to_wds.py).
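For orientation, the general pattern of writing images into webdataset shards looks roughly like the following (a minimal sketch using webdataset's `ShardWriter`; the dataset path, shard pattern, and shard size are placeholders, and the linked conversion script remains the authoritative reference):
```python
import io
from pathlib import Path

import webdataset as wds
from PIL import Image

# Hypothetical input layout: one folder per class, as in ImageNet.
dataset_root = Path("imagenet/train")  # placeholder path
class_dirs = sorted(d for d in dataset_root.iterdir() if d.is_dir())

# Write samples into sharded .tar files, ~10k samples per shard.
with wds.ShardWriter("imagenet-train-%06d.tar", maxcount=10_000) as sink:
    for label, class_dir in enumerate(class_dirs):
        for img_path in sorted(class_dir.glob("*.JPEG")):
            buffer = io.BytesIO()
            Image.open(img_path).convert("RGB").save(buffer, format="JPEG")
            sink.write({
                "__key__": img_path.stem,  # unique sample key
                "jpg": buffer.getvalue(),  # encoded image bytes
                "cls": label,              # integer class label
            })
```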
Furthermore, the stage-1 training relies on a pre-trained MaskGIT-VQGAN to generate proxy codes as learning targets. You can convert the [official JAX weights](https://github.com/google-research/maskgit) to a PyTorch version using [this script](https://github.com/huggingface/open-muse/blob/main/scripts/convert_maskgit_vqgan.py). Alternatively, we provide a converted version at [HuggingFace](https://huggingface.co/fun-research/TiTok/blob/main/maskgit-vqgan-imagenet-f16-256.bin) and [Google Drive](https://drive.google.com/file/d/1DjZqzJrUt2hwpmUPkjGSBTFEJcOkLY-Q/view?usp=sharing). The MaskGIT-VQGAN weights will be downloaded automatically when you run the training script.
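Conceptually, the proxy codes are simply the discrete indices obtained by encoding each training image with the frozen MaskGIT-VQGAN, which stage-1 TiTok then learns to predict. A rough sketch of the idea (the class name, import path, and `encode` signature below are assumptions for illustration; the training script handles this internally):
```python
import torch

# Illustration only: the class name and import path are assumed, not
# verbatim from the codebase; see scripts/train_titok.py for the actual wiring.
from modeling.modules import PretrainedTokenizer  # assumed import

maskgit_vqgan = PretrainedTokenizer("maskgit-vqgan-imagenet-f16-256.bin")
maskgit_vqgan.eval()
maskgit_vqgan.requires_grad_(False)

images = torch.rand(8, 3, 256, 256)  # a batch of training images in [0, 1]
with torch.no_grad():
    proxy_codes = maskgit_vqgan.encode(images)  # discrete indices, e.g. shape (8, 256)

# Stage-1 TiTok is trained to predict these proxy codes from its own
# 1D latent tokens, rather than regressing pixels directly.
```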
## Training
We provide example commands to train TiTok as follows:
```bash
# Training for TiTok-B64
# Stage 1
WANDB_MODE=offline accelerate launch --num_machines=1 --num_processes=8 --machine_rank=0 --main_process_ip=127.0.0.1 --main_process_port=9999 --same_network scripts/train_titok.py config=configs/training/stage1/titok_b64.yaml \
experiment.project="titok_b64_stage1" \
experiment.name="titok_b64_stage1_run1" \
experiment.output_dir="titok_b64_stage1_run1" \
training.per_gpu_batch_size=32
# Stage 2
WANDB_MODE=offline accelerate launch --num_machines=1 --num_processes=8 --machine_rank=0 --main_process_ip=127.0.0.1 --main_process_port=9999 --same_network scripts/train_titok.py config=configs/training/stage2/titok_b64.yaml \
experiment.project="titok_b64_stage2" \
experiment.name="titok_b64_stage2_run1" \
experiment.output_dir="titok_b64_stage2_run1" \
training.per_gpu_batch_size=32 \
experiment.init_weight=${PATH_TO_STAGE1_WEIGHT}
# Train Generator (TiTok-B64 as example)
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_maskgit.py config=configs/training/generator/maskgit.yaml \
experiment.project="titok_generation" \
experiment.name="titok_b64_maskgit" \
experiment.output_dir="titok_b64_maskgit" \
experiment.tokenizer_checkpoint=${PATH_TO_STAGE1_or_STAGE2_WEIGHT}
```
You may remove the `WANDB_MODE=offline` flag to enable online wandb logging, if you have configured it.
The config _titok_b64.yaml_ can be replaced with _titok_s128.yaml_ or _titok_l32.yaml_ for other TiTok variants.
## Visualizations
<p>
<img src="assets/recon_w_model_size_num_token.png" alt="teaser" width=90% height=90%>
</p>
<p>
<img src="assets/random_vis_l32.png" alt="teaser" width=90% height=90%>
</p>
## Citing
If you use our work in your research, please use the following BibTeX entry.
```BibTeX
@inproceedings{yu2024an,
author = {Qihang Yu and Mark Weber and Xueqing Deng and Xiaohui Shen and Daniel Cremers and Liang-Chieh Chen},
title = {An Image is Worth 32 Tokens for Reconstruction and Generation},
    booktitle = {NeurIPS},
year = {2024}
}
```
## Acknowledgement
[MaskGIT](https://github.com/google-research/maskgit)

[Taming-Transformers](https://github.com/CompVis/taming-transformers)

[Open-MUSE](https://github.com/huggingface/open-muse)

[MUSE-Pytorch](https://github.com/baaivision/MUSE-Pytorch)