# Terra / modeling_lfq_tokenizer.py
# Uploaded by koukyo1994 — "upload LFQ implementation" (commit 07155e5, verified)
"""
Hugging Face-compatible implementation of Open-MAGVIT2
Code reference: https://github.com/TencentARC/Open-MAGVIT2
"""
from math import log2, ceil
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, pack, unpack
from torch import einsum
from torch.nn import Module
from transformers import PreTrainedModel
from .configuration_lfq_tokenizer import LFQTokenizerConfig
def swish(x):
    """Swish (a.k.a. SiLU) activation, applied elementwise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
class ResBlock(nn.Module):
    """Pre-activation residual block.

    Main path: GroupNorm -> swish -> 3x3 conv, twice (both convs bias-free).
    Skip path: identity when ``in_filters == out_filters``; otherwise a
    bias-free projection, 3x3 (``use_conv_shortcut=True``) or 1x1.
    """

    def __init__(self,
                 in_filters,
                 out_filters,
                 use_conv_shortcut = False
                 ) -> None:
        super().__init__()
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.use_conv_shortcut = use_conv_shortcut

        self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6)
        self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6)
        self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)

        if in_filters == out_filters:
            return  # identity skip: no projection layer needed
        if use_conv_shortcut:
            self.conv_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
        else:
            self.nin_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding=0, bias=False)

    def _shortcut(self, x):
        """Return the skip-path tensor, projected when the channel count changes."""
        if self.in_filters == self.out_filters:
            return x
        if self.use_conv_shortcut:
            return self.conv_shortcut(x)
        return self.nin_shortcut(x)

    def forward(self, x, **kwargs):
        # extra keyword arguments are accepted and ignored
        h = self.conv1(swish(self.norm1(x)))
        h = self.conv2(swish(self.norm2(h)))
        return h + self._shortcut(x)
class Encoder(nn.Module):
    """Downsampling convolutional encoder producing ``z_channels`` feature maps.

    Layout: bias-free 3x3 ``conv_in`` -> one stage per ``ch_mult`` entry
    (``num_res_blocks`` ResBlocks; every stage except the last ends in a
    stride-2 3x3 conv) -> ``num_res_blocks`` middle ResBlocks -> GroupNorm ->
    swish -> 1x1 ``conv_out``.

    ``out_ch`` is accepted (presumably for signature parity with ``Decoder``)
    but is not used.
    """

    def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4)):
        super().__init__()
        self.in_channels = in_channels
        self.z_channels = z_channels
        self.num_res_blocks = num_res_blocks
        self.num_blocks = len(ch_mult)

        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=(3, 3), padding=1, bias=False)

        # Stage i maps ch * mult_in -> ch * ch_mult[i], where mult_in is the
        # previous stage's multiplier (1 for the first stage).
        self.down = nn.ModuleList()
        last_level = self.num_blocks - 1
        for level, (mult_in, mult_out) in enumerate(zip((1,) + tuple(ch_mult), ch_mult)):
            width_in = ch * mult_in
            width_out = ch * mult_out
            stage_blocks = nn.ModuleList()
            for _ in range(num_res_blocks):
                stage_blocks.append(ResBlock(width_in, width_out))
                width_in = width_out
            stage = nn.Module()
            stage.block = stage_blocks
            if level < last_level:
                stage.downsample = nn.Conv2d(width_out, width_out, kernel_size=(3, 3), stride=(2, 2), padding=1)
            self.down.append(stage)

        # bottleneck ResBlocks at the deepest width
        self.mid_block = nn.ModuleList(ResBlock(width_in, width_in) for _ in range(num_res_blocks))

        # output head
        self.norm_out = nn.GroupNorm(32, width_out, eps=1e-6)
        self.conv_out = nn.Conv2d(width_out, z_channels, kernel_size=(1, 1))

    def forward(self, x):
        x = self.conv_in(x)
        last_level = self.num_blocks - 1
        for level, stage in enumerate(self.down):
            for res_block in stage.block:
                x = res_block(x)
            if level < last_level:
                x = stage.downsample(x)
        for res_block in self.mid_block:
            x = res_block(x)
        return self.conv_out(swish(self.norm_out(x)))
class Decoder(nn.Module):
    """Upsampling convolutional decoder mapping ``z_channels`` maps to ``out_ch`` channels.

    Layout: 3x3 ``conv_in`` (with bias) -> ``num_res_blocks`` middle ResBlocks
    -> stages from the deepest level up (``num_res_blocks`` ResBlocks each;
    every level above 0 ends in an ``Upsampler``) -> GroupNorm -> swish ->
    3x3 ``conv_out``.  ``self.up[i]`` is the stage for level ``i``.
    """

    def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4)) -> None:
        super().__init__()
        self.ch = ch
        self.num_blocks = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels

        width = ch * ch_mult[self.num_blocks - 1]
        self.conv_in = nn.Conv2d(z_channels, width, kernel_size=(3, 3), padding=1, bias=True)

        self.mid_block = nn.ModuleList(ResBlock(width, width) for _ in range(num_res_blocks))

        # Built deepest-first (that fixes the module creation order); stored
        # reversed so that the list index equals the level.
        stages = []
        for level in reversed(range(self.num_blocks)):
            target = ch * ch_mult[level]
            stage_blocks = nn.ModuleList()
            for _ in range(num_res_blocks):
                stage_blocks.append(ResBlock(width, target))
                width = target
            stage = nn.Module()
            stage.block = stage_blocks
            if level > 0:
                stage.upsample = Upsampler(width)
            stages.append(stage)
        self.up = nn.ModuleList(reversed(stages))

        self.norm_out = nn.GroupNorm(32, width, eps=1e-6)
        self.conv_out = nn.Conv2d(width, out_ch, kernel_size=(3, 3), padding=1)

    def forward(self, z):
        z = self.conv_in(z)
        for res_block in self.mid_block:
            z = res_block(z)
        for level in reversed(range(self.num_blocks)):
            stage = self.up[level]
            for res_block in stage.block:
                z = res_block(z)
            if level > 0:
                z = stage.upsample(z)
        return self.conv_out(swish(self.norm_out(z)))
def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor:
    """Move channel blocks into spatial blocks (Depth-to-Space, DCR mode).

    DCR ("depth-column-row") ordering: the channel axis is split as
    ``(block_size, block_size, C // block_size**2)`` with the block offsets
    outermost.  Input channel ``(i * block_size + j) * (C // block_size**2) + k``
    therefore lands in output channel ``k`` at offset row ``i``, column ``j``
    inside each ``block_size x block_size`` output block.  Note this is not the
    channel ordering used by ``torch.nn.functional.pixel_shuffle``.

    Args:
        x: channels-first tensor ``(*, C, H, W)`` with ``C`` divisible by
            ``block_size ** 2``.  It does not need to be contiguous.
        block_size: side length of each spatial block.

    Returns:
        Contiguous tensor of shape
        ``(*, C // block_size**2, H * block_size, W * block_size)``.

    Raises:
        ValueError: if ``x`` has fewer than 3 dimensions, or ``C`` is not
            divisible by ``block_size ** 2``.
    """
    if x.dim() < 3:
        raise ValueError(
            f"Expecting a channels-first (*CHW) tensor of at least 3 dimensions, but got {x.dim()} dimensions"
        )
    c, h, w = x.shape[-3:]
    s = block_size ** 2
    if c % s != 0:
        raise ValueError(
            f"Expecting a channels-first (*CHW) tensor with C divisible by {s}, but got C={c} channels"
        )
    outer_dims = x.shape[:-3]
    # Split channels into (row offset, column offset, output channel) and fold
    # all leading dims into one.  reshape (not view) so inputs whose leading
    # dims cannot be merged without a copy (e.g. transposed batch dims) work.
    x = x.reshape(-1, block_size, block_size, c // s, h, w)
    # -> (N, C', H, row offset, W, column offset)
    x = x.permute(0, 3, 4, 1, 5, 2)
    # merge each offset with its spatial axis
    return x.contiguous().view(*outer_dims, c // s, h * block_size, w * block_size)
class Upsampler(nn.Module):
    """2x spatial upsampler: 3x3 conv to ``4 * dim_out`` channels, then depth-to-space.

    Args:
        dim: number of input channels.
        dim_out: number of output channels; defaults to ``dim``.  This argument
            used to be overwritten unconditionally, so the output always had
            ``dim`` channels.  The default path is unchanged.
    """

    def __init__(
        self,
        dim,
        dim_out = None
    ):
        super().__init__()
        if dim_out is None:
            dim_out = dim
        # depth_to_space(block_size=2) folds 4 channels into each 2x2 output block
        self.conv1 = nn.Conv2d(dim, dim_out * 4, (3, 3), padding=1)
        self.depth2space = depth_to_space

    def forward(self, x):
        """
        input_image: [B C H W] -> [B dim_out 2H 2W]
        """
        out = self.conv1(x)
        return self.depth2space(out, block_size=2)
class AdaptiveGroupNorm(nn.Module):
    """GroupNorm whose affine scale/shift are predicted from a conditioning map.

    ``x`` (``in_filters`` channels) is group-normalised without learned affine
    parameters.  It is then modulated per channel by:

    * scale = ``gamma(sqrt(var_hw(quantizer) + eps))``
    * shift = ``beta(mean_hw(quantizer))``

    The statistics are taken over the spatial positions of each
    ``quantizer`` channel (``z_channel`` of them).

    Args:
        z_channel: channel count of the conditioning tensor ``quantizer``.
        in_filters: channel count of ``x``.
        num_groups: GroupNorm groups (must divide ``in_filters``); default 32.
        eps: epsilon for GroupNorm and for the variance before ``sqrt``.
    """

    def __init__(self, z_channel, in_filters, num_groups=32, eps=1e-6):
        super().__init__()
        # num_groups was previously ignored (always 32); honour it.
        self.gn = nn.GroupNorm(num_groups=num_groups, num_channels=in_filters, eps=eps, affine=False)
        self.gamma = nn.Linear(z_channel, in_filters)
        self.beta = nn.Linear(z_channel, in_filters)
        self.eps = eps

    def forward(self, x, quantizer):
        """Normalise ``x`` (B, C, H, W) and modulate it with statistics of ``quantizer`` (b, c, h, w)."""
        B, C, _, _ = x.shape
        flat = rearrange(quantizer, "b c h w -> b c (h w)")
        # Tensor.var uses the unbiased (Bessel-corrected) estimator by default.
        std = (flat.var(dim=-1) + self.eps).sqrt()
        scale = self.gamma(std).view(B, C, 1, 1)
        bias = self.beta(flat.mean(dim=-1)).view(B, C, 1, 1)
        return scale * self.gn(x) + bias
# constants
# Auxiliary quantizer outputs: per-sample entropy, codebook entropy,
# commitment loss, and averaged code probabilities.
LossBreakdown = namedtuple('LossBreakdown', 'per_sample_entropy codebook_entropy commitment avg_probs')
# helper functions
def exists(v):
    """Return True unless ``v`` is ``None`` (falsy values such as 0 or '' count as existing)."""
    is_missing = v is None
    return not is_missing
def default(*args):
    """Return the first argument that is not ``None``.

    If that argument is callable, it is called with no arguments and its result
    is returned instead.  This allows lazy defaults such as
    ``default(size, lambda: 2 ** dim)``.  Returns ``None`` when every argument
    is ``None`` (or when none are given).
    """
    chosen = next((candidate for candidate in args if exists(candidate)), None)
    if chosen is None:
        return None
    return chosen() if callable(chosen) else chosen
def pack_one(t, pattern):
    """Pack a single tensor with ``einops.pack``; returns ``(packed, packed_shapes)``."""
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
def unpack_one(t, ps, pattern):
    """Inverse of ``pack_one``: unpack ``t`` with the shapes ``ps`` and return the first piece."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]
# entropy
def entropy(prob):
    """Entropy (natural log) of the distributions along the last dimension.

    ``1e-5`` is added inside the log so zero probabilities stay finite.
    """
    log_prob = torch.log(prob + 1e-5)
    return (-prob * log_prob).sum(dim=-1)
# masked-reduction helpers
def mult_along_first_dims(x, y):
    """Multiply ``x`` by ``y``, aligning ``y`` with the *leading* dimensions of ``x``.

    Trailing singleton dimensions are appended to ``y`` until it has as many
    dimensions as ``x``, and the two are then broadcast.  For example,
    ``x: (B, N, D)`` with ``y: (B, N)`` scales each length-``D`` row of ``x``
    by one entry of ``y``.
    """
    missing = x.ndim - y.ndim
    if missing > 0:
        y = y[(...,) + (None,) * missing]
    return x * y
def masked_mean(x, m):
    """Mean of ``x`` over the positions selected by mask ``m``.

    ``m`` spans the leading ``m.ndim`` dimensions of ``x``, and the reduction
    runs over exactly those dimensions, so the result has shape
    ``x.shape[m.ndim:]``.  This mirrors ``x[m].mean(0)``.

    It multiplies by the mask instead of using boolean indexing, which avoids
    data-dependent shapes and is much faster under ``torch.compile``.  The
    price is somewhat larger floating-point error.  An all-zero mask divides
    by zero.
    """
    weighted = mult_along_first_dims(x, m) / m.sum()
    leading_dims = tuple(range(m.ndim))
    return weighted.sum(leading_dims)
def entropy_loss(
    logits,
    mask=None,
    temperature=0.01,
    sample_minimization_weight=1.0,
    batch_maximization_weight=1.0,
    eps=1e-5,
):
    """Entropy regulariser computed from unnormalised code affinities.

    Args:
        logits: affinities over codes along the last dimension.
        mask: optional mask over the leading dims of ``logits``; when given,
            both averages are restricted to the masked positions via
            ``masked_mean``.
        temperature: softmax temperature (``logits / temperature``).
        sample_minimization_weight: weight of the mean per-sample entropy
            (enters the loss with a ``+`` sign, i.e. it is minimised).
        batch_maximization_weight: weight of the entropy of the averaged
            distribution (enters with a ``-`` sign, i.e. it is maximised).
        eps: stabiliser inside the logarithms.

    Returns:
        ``(sample_entropy, avg_entropy, loss)`` where
        ``loss = sample_minimization_weight * sample_entropy
        - batch_maximization_weight * avg_entropy``.

    Reference:
        https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279
        "Language Model Beats Diffusion -- Tokenizer Is Key to Visual Generation" (2024)
    """
    scaled = logits / temperature
    probs = F.softmax(scaled, -1)
    # Adding eps before log_softmax has no mathematical effect (softmax is
    # shift-invariant); it is kept for numerical parity with the reference.
    log_probs = F.log_softmax(scaled + eps, -1)
    per_position_entropy = -torch.sum(probs * log_probs, -1)

    if mask is None:
        avg_probs = reduce(probs, "... D -> D", "mean")
        sample_entropy = torch.mean(per_position_entropy)
    else:
        avg_probs = masked_mean(probs, mask)
        sample_entropy = masked_mean(per_position_entropy, mask).mean()

    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps))
    loss = (sample_minimization_weight * sample_entropy) - (batch_maximization_weight * avg_entropy)
    return sample_entropy, avg_entropy, loss
class LFQ(Module):
    """Lookup-free quantizer (LFQ) in the style of MAGVIT-v2 / Open-MAGVIT2.

    Each feature is binarised to ``+1`` (if ``> 0``) or ``-1`` (otherwise), so
    the implicit codebook is every vector in ``{-1, +1}^codebook_dim``.
    ``forward`` turns the sign pattern into an integer index by reading the
    bits most-significant-bit first (weights in the ``mask`` buffer).

    Args:
        dim: feature dimension of the input.  Defaults to
            ``log2(codebook_size) * num_codebooks``.
        codebook_size: number of codes; must be a power of two.  Defaults to
            ``2 ** dim``.
        num_codebooks: number of groups the feature dimension is split into.
        sample_minimization_weight: weight of the per-sample entropy term.
        batch_maximization_weight: weight of the codebook-usage entropy term.
        token_factorization: if True, each index is split into a "pre" and a
            "post" index of ``codebook_dim // 2`` bits each.
    """

    def __init__(
        self,
        *,
        dim = None,
        codebook_size = None,
        num_codebooks = 1,
        sample_minimization_weight=1.0,
        batch_maximization_weight=1.0,
        token_factorization = False,
    ):
        super().__init__()

        # some assert validations
        assert dim is not None or codebook_size is not None, 'either dim or codebook_size must be specified for LFQ'
        assert codebook_size is None or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

        # Resolve the default first and derive everything from the resolved
        # value.  Reading the raw argument made LFQ(dim=...) fail on log2(None).
        if codebook_size is None:
            codebook_size = 2 ** dim
        self.codebook_size = codebook_size
        self.codebook_dim = int(log2(codebook_size))

        codebook_dims = self.codebook_dim * num_codebooks
        if dim is None:
            dim = codebook_dims

        # NOTE: only recorded; this module builds no projection layers.
        self.has_projections = dim != codebook_dims
        self.dim = dim
        self.num_codebooks = num_codebooks

        # for entropy loss
        self.sample_minimization_weight = sample_minimization_weight
        self.batch_maximization_weight = batch_maximization_weight

        # token factorization is only utilized by the second-stage model
        self.token_factorization = token_factorization
        # MSB-first bit weights used to turn sign patterns into indices
        if not self.token_factorization:
            self.register_buffer('mask', 2 ** torch.arange(self.codebook_dim - 1, -1, -1), persistent=False)
        else:
            k = self.codebook_dim // 2
            self.register_buffer("mask", 2 ** torch.arange(k - 1, -1, -1), persistent=False)

        # placeholder returned for losses that are not computed
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

        # codes: row j is indices_to_bits(j) (LSB first) mapped to {-1, +1}
        all_codes = torch.arange(codebook_size)
        bits = self.indices_to_bits(all_codes)
        codebook = bits * 2.0 - 1.0
        self.register_buffer('codebook', codebook, persistent = False)

    @property
    def dtype(self):
        """Floating dtype of the ``codebook`` buffer."""
        return self.codebook.dtype

    def indices_to_bits(self, x):
        """Expand integer indices into boolean bits along a new last dimension.

        x: integer tensor of indices (used here to build the codebook).
        Returns a bool tensor ``x.shape + (codebook_dim,)`` whose element ``i``
        is bit ``i`` of the index, i.e. least-significant bit FIRST.

        NOTE(review): ``forward`` and ``get_codebook_entry`` use the opposite
        (most-significant-bit-first) order, so codebook row ``j`` is generally
        not the code that ``forward`` labels with index ``j``.  The entropy
        loss sums over all codes and is unaffected.
        """
        mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long)
        return (x.unsqueeze(-1) & mask) != 0

    def get_codebook_entry(self, x, bhwc):
        """Map integer indices back to ``{-1, +1}`` feature maps (MSB-first bits).

        Args:
            x: integer index tensor laid out as ``(b, h * w)`` (required by the
                ``"b (h w) c"`` rearrange below).
            bhwc: ``(b, h, w, c)``; ``c`` must equal the number of bits per
                index: ``codebook_dim``, or ``codebook_dim // 2`` with
                ``token_factorization``.

        Returns:
            floating tensor ``(b, c, h, w)`` with values in ``{-1.0, 1.0}``.
        """
        if self.token_factorization:
            k = self.codebook_dim // 2
            mask = 2 ** torch.arange(k - 1, -1, -1, device=x.device, dtype=torch.long)
        else:
            mask = 2 ** torch.arange(self.codebook_dim - 1, -1, -1, device=x.device, dtype=torch.long)

        x = (x.unsqueeze(-1) & mask) != 0
        x = x * 2.0 - 1.0  # back to float
        ## scale back to the desired shape
        b, h, w, c = bhwc
        x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c)
        x = rearrange(x, "b h w c -> b c h w")
        return x

    def bits_to_indices(self, bits):
        """Inverse of ``indices_to_bits``.

        bits: bool tensor whose last dimension (size ``codebook_dim``) holds the
        bits least-significant-bit first.  Returns integer indices in
        ``[0, codebook_size)``.
        """
        assert bits.shape[-1] == self.codebook_dim
        weights = 2 ** torch.arange(
            0,
            self.codebook_dim,
            1,
            dtype=torch.long,
            device=bits.device,
        )
        return (bits * weights).sum(-1)

    def decode(self, x):
        """Convert indices ``(..., num_codebooks)`` to codes ``(..., num_codebooks * codebook_dim)``.

        Values are ``-1``/``+1`` in ``self.dtype``.  Bits come from
        ``indices_to_bits`` (LSB first), so this is NOT the inverse of the MSB-first
        indices produced by ``forward``; ``get_codebook_entry`` is.
        """
        x = self.indices_to_bits(x)
        # to some sort of float
        x = x.to(self.dtype)
        # -1 or 1
        x = x * 2 - 1
        x = rearrange(x, "... NC Z-> ... (NC Z)")
        return x

    def forward(
        self,
        x,
        return_loss_breakdown = False,
        mask = None,
        return_loss = True,
    ):
        """Quantize ``x`` by sign and compute the auxiliary losses.

        einstein notation: b - batch, n - flattened spatial positions,
        c - number of codebooks, d - bits per codebook (log2 of codebook size).

        Args:
            x: tensor ``(b, d_total, ...)`` with features in dim 1; ``d_total``
                is split into ``num_codebooks`` groups.
            return_loss_breakdown: also return a ``LossBreakdown``.
            mask: optional boolean mask indexing the leading dims of the
                ``(b, n, c, d)`` commitment loss; used only in training.
            return_loss: compute the entropy auxiliary loss when training.

        Returns:
            ``(quantized, entropy_aux_loss, indices)``.
            - ``quantized`` has the input's shape and values ``+/-1``, with
              straight-through gradients.
            - ``indices`` is a flattened integer tensor, or a ``(pre, post)``
              pair with ``token_factorization``.
            - With ``return_loss_breakdown=True`` the result is
              ``(that_tuple, LossBreakdown)``.
            - Entropy terms are ``self.zero`` unless training with
              ``return_loss``; the commitment loss is ``self.zero`` outside
              training.
        """
        x = rearrange(x, 'b d ... -> b ... d')
        x, ps = pack_one(x, 'b * d')
        # split out number of codebooks
        x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)

        codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype)
        quantized = torch.where(x > 0, codebook_value, -codebook_value)  # > 0 -> +1, else -1

        # calculate indices (MSB first, weights from self.mask)
        if self.token_factorization:
            k = self.codebook_dim // 2
            indices_pre = reduce((quantized[..., :k] > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
            indices_post = reduce((quantized[..., k:] > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
        else:
            indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

        # entropy aux loss
        if self.training and return_loss:
            # the same as euclidean distance up to a constant
            logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook)
            per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss(
                logits = logits,
                sample_minimization_weight = self.sample_minimization_weight,
                batch_maximization_weight = self.batch_maximization_weight
            )
            avg_probs = self.zero
        else:
            # not computed: return the zero placeholder
            per_sample_entropy = codebook_entropy = self.zero
            entropy_aux_loss = self.zero
            avg_probs = self.zero

        # commit loss
        if self.training:
            commit_loss = F.mse_loss(x, quantized.detach(), reduction = 'none')
            if mask is not None:
                commit_loss = commit_loss[mask]
            commit_loss = commit_loss.mean()
        else:
            commit_loss = self.zero

        # straight-through estimator: forward value is quantized, gradient flows to x
        quantized = x + (quantized - x).detach()

        # merge back codebook dim
        quantized = rearrange(quantized, 'b n c d -> b n (c d)')
        # reconstitute image or video dimensions
        quantized = unpack_one(quantized, ps, 'b * d')
        quantized = rearrange(quantized, 'b ... d -> b d ...')

        if self.token_factorization:
            indices_pre = unpack_one(indices_pre, ps, "b * c").flatten()
            indices_post = unpack_one(indices_post, ps, "b * c").flatten()
            indices = (indices_pre, indices_post)
        else:
            indices = unpack_one(indices, ps, 'b * c').flatten()

        ret = (quantized, entropy_aux_loss, indices)
        if not return_loss_breakdown:
            return ret
        return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss, avg_probs)
class LFQTokenizer(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapping the Open-MAGVIT2 encoder, LFQ quantizer and decoder.

    ``config.encoder_decoder_config`` is passed as keyword arguments to both
    ``Encoder`` and ``Decoder``.  ``config.quantizer_config`` is passed to ``LFQ``.
    """

    config_class = LFQTokenizerConfig

    def __init__(self, config: LFQTokenizerConfig):
        super().__init__(config)
        self.encoder = Encoder(**config.encoder_decoder_config)
        self.decoder = Decoder(**config.encoder_decoder_config)
        self.quantize = LFQ(**config.quantizer_config)

    def encode(self, x):
        """Encode and quantize ``x``; returns ``(quant, entropy_aux_loss, indices, loss_breakdown)``."""
        h = self.encoder(x)
        (quant, emb_loss, info), loss_breakdown = self.quantize(h, return_loss_breakdown=True)
        return quant, emb_loss, info, loss_breakdown

    def decode(self, quant):
        """Decode quantized latents back to image space."""
        return self.decoder(quant)

    def forward(self, input):
        # NOTE: ``input`` shadows the builtin but is kept for keyword-call compatibility.
        quant, diff, _, loss_breakdown = self.encode(input)
        dec = self.decoder(quant)
        return dec, diff, loss_breakdown

    def tokenize(self, input):
        """Return the (flattened) token indices for ``input``; a ``(pre, post)`` pair with token factorization."""
        _, _, tokens, _ = self.encode(input)
        return tokens

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def decode_tokens(self, tokens, shape: tuple):
        """Decode token indices produced by ``tokenize``.

        Args:
            tokens: index tensor, either 1-D (as returned by ``tokenize``) or
                ``(b, h * w)``.  With token factorization, a ``(pre, post)``
                pair of such tensors.
            shape: ``(b, h, w, c)`` passed to ``LFQ.get_codebook_entry``.  ``c``
                is the number of bits per token (half the codebook dim when
                factorized).
        """
        batch_size = shape[0]

        def _batched(t):
            # tokenize() flattens indices to 1-D; get_codebook_entry needs (b, h*w).
            return t.view(batch_size, -1) if t.ndim == 1 else t

        if self.quantize.token_factorization:
            tokens_pre, tokens_post = tokens[0], tokens[1]
            quant_pre = self.quantize.get_codebook_entry(_batched(tokens_pre), shape)
            quant_post = self.quantize.get_codebook_entry(_batched(tokens_post), shape)
            quant = torch.concat([quant_pre, quant_post], dim=1)
        else:
            quant = self.quantize.get_codebook_entry(_batched(tokens), shape)
        return self.decode(quant)