"""This file contains the model definition of TiTok. | |
Copyright (2024) Bytedance Ltd. and/or its affiliates | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
""" | |
import json
from pathlib import Path

import torch
import torch.nn as nn
from einops import rearrange
from omegaconf import OmegaConf
from huggingface_hub import PyTorchModelHubMixin

from modeling.modules.base_model import BaseModel
from modeling.modules.blocks import TiTokEncoder, TiTokDecoder
from modeling.quantizer.quantizer import VectorQuantizer, DiagonalGaussianDistribution
from modeling.modules.maskgit_vqgan import Encoder as Pixel_Encoder
from modeling.modules.maskgit_vqgan import Decoder as Pixel_Decoder
from modeling.modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
class PretrainedTokenizer(nn.Module):
    """Frozen MaskGiT-VQGAN tokenizer: maps images to discrete codes and back."""
    def __init__(self, pretrained_weight):
        super().__init__()
        conf = OmegaConf.create(
            {"channel_mult": [1, 1, 2, 2, 4],
             "num_resolutions": 5,
             "dropout": 0.0,
             "hidden_channels": 128,
             "num_channels": 3,
             "num_res_blocks": 2,
             "resolution": 256,
             "z_channels": 256})
        self.encoder = Pixel_Encoder(conf)
        self.decoder = Pixel_Decoder(conf)
        self.quantize = Pixel_Quantizer(
            num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
        # Load pretrained weights and freeze the whole module.
        self.load_state_dict(torch.load(pretrained_weight, map_location=torch.device("cpu")), strict=True)
        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    def encode(self, x):
        """Map images to discrete codebook indices."""
        hidden_states = self.encoder(x)
        quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states)
        return codebook_indices.detach()

    def decode(self, codes):
        """Map codebook indices back to images, clamped to [0, 1]."""
        quantized_states = self.quantize.get_codebook_entry(codes)
        rec_images = self.decoder(quantized_states)
        rec_images = torch.clamp(rec_images, 0.0, 1.0)
        return rec_images.detach()

    def decode_tokens(self, codes):
        return self.decode(codes)
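
# A minimal usage sketch (illustrative only): the checkpoint filename is an
# assumption, and `images` stands for any float tensor in [0, 1] of shape
# (B, 3, 256, 256).
#
#     tokenizer = PretrainedTokenizer("maskgit_vqgan.bin")
#     codes = tokenizer.encode(images)   # integer code indices per image
#     recon = tokenizer.decode(codes)    # reconstructed images in [0, 1]
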
class TiTok(BaseModel, PyTorchModelHubMixin,
            tags=["arxiv:2406.07550", "image-tokenization"],
            repo_url="https://github.com/bytedance/1d-tokenizer",
            license="apache-2.0"):
    def __init__(self, config):
        if isinstance(config, dict):
            config = OmegaConf.create(config)

        super().__init__()
        self.config = config
        # This should be False for stage 1 and True for stage 2.
        self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True)
        self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")
        if self.quantize_mode not in ["vq", "vae"]:
            raise ValueError(f"Unsupported quantize mode {self.quantize_mode}.")

        if self.finetune_decoder and self.quantize_mode not in ["vq"]:
            raise ValueError("Only support finetune_decoder with vq quantization for now.")

        self.encoder = TiTokEncoder(config)
        self.decoder = TiTokDecoder(config)

        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        scale = self.encoder.width ** -0.5
        self.latent_tokens = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.encoder.width))

        self.apply(self._init_weights)

        if self.quantize_mode == "vq":
            self.quantize = VectorQuantizer(
                codebook_size=config.model.vq_model.codebook_size,
                token_size=config.model.vq_model.token_size,
                commitment_cost=config.model.vq_model.commitment_cost,
                use_l2_norm=config.model.vq_model.use_l2_norm)
        elif self.quantize_mode == "vae":
            self.quantize = DiagonalGaussianDistribution
        else:
            raise NotImplementedError

        if self.finetune_decoder:
            # Freeze encoder/quantizer/latent tokens.
            self.latent_tokens.requires_grad_(False)
            self.encoder.eval()
            self.encoder.requires_grad_(False)
            self.quantize.eval()
            self.quantize.requires_grad_(False)
            # Include MaskGiT-VQGAN's quantizer and decoder.
            self.pixel_quantize = Pixel_Quantizer(
                num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
            self.pixel_decoder = Pixel_Decoder(OmegaConf.create(
                {"channel_mult": [1, 1, 2, 2, 4],
                 "num_resolutions": 5,
                 "dropout": 0.0,
                 "hidden_channels": 128,
                 "num_channels": 3,
                 "num_res_blocks": 2,
                 "resolution": 256,
                 "z_channels": 256}))

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config to a local directory."""
        # Convert the OmegaConf config to a plain dict and dump it as JSON
        # alongside the weights.
        dict_config = OmegaConf.to_container(self.config)
        file_path = Path(save_directory) / "config.json"
        with open(file_path, 'w') as json_file:
            json.dump(dict_config, json_file, indent=4)
        super()._save_pretrained(save_directory)

    def _init_weights(self, module):
        """Initialize the weights.

        Args:
            module: torch.nn.Module to initialize.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def encode(self, x):
        if self.finetune_decoder:
            # Stage 2: the encoder and quantizer are frozen, so run them
            # without gradients and zero out the auxiliary quantizer losses.
            with torch.no_grad():
                self.encoder.eval()
                self.quantize.eval()
                z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens)
                z_quantized, result_dict = self.quantize(z)
                result_dict["quantizer_loss"] *= 0
                result_dict["commitment_loss"] *= 0
                result_dict["codebook_loss"] *= 0
        else:
            z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens)
            if self.quantize_mode == "vq":
                z_quantized, result_dict = self.quantize(z)
            elif self.quantize_mode == "vae":
                # In "vae" mode, result_dict is the posterior distribution itself.
                posteriors = self.quantize(z)
                z_quantized = posteriors.sample()
                result_dict = posteriors
        return z_quantized, result_dict

    def decode(self, z_quantized):
        decoded = self.decoder(z_quantized)
        if self.finetune_decoder:
            # The decoder predicts logits over the MaskGiT-VQGAN codebook;
            # softly combine the codebook embeddings and decode them to pixels.
            quantized_states = torch.einsum(
                'nchw,cd->ndhw', decoded.softmax(1),
                self.pixel_quantize.embedding.weight)
            decoded = self.pixel_decoder(quantized_states)
        return decoded

    def decode_tokens(self, tokens):
        if self.quantize_mode == "vq":
            tokens = tokens.squeeze(1)
            batch, seq_len = tokens.shape  # B x N
            z_quantized = self.quantize.get_codebook_entry(
                tokens.reshape(-1)).reshape(batch, 1, seq_len, -1)
            z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
        elif self.quantize_mode == "vae":
            z_quantized = tokens
        decoded = self.decode(z_quantized)
        return decoded

    def forward(self, x):
        z_quantized, result_dict = self.encode(x)
        decoded = self.decode(z_quantized)
        return decoded, result_dict
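
# A minimal end-to-end sketch (illustrative): the config path and hub id below
# are assumptions; a real config must also carry the encoder/decoder fields
# that TiTokEncoder/TiTokDecoder require, not just the vq_model entries.
#
#     config = OmegaConf.load("configs/titok_l32.yaml")   # assumed path
#     model = TiTok(config)
#     # or, via PyTorchModelHubMixin:
#     # model = TiTok.from_pretrained("yucornetto/tokenizer_titok_l32_imagenet")  # example hub id
#     images = torch.rand(1, 3, 256, 256)   # dummy batch in [0, 1]
#     reconstruction, result_dict = model(images)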