# -------------------------------------------------------- # BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers (https://arxiv.org/abs/2208.06366) # Github source: https://github.com/microsoft/unilm/tree/master/beitv2 # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # By Zhiliang Peng # Based on VQGAN code bases # https://github.com/CompVis/taming-transformers # --------------------------------------------------------' import torch import torch.nn as nn import torch.nn.functional as F import torch.distributed as distributed from einops import rearrange, repeat def l2norm(t): return F.normalize(t, p = 2, dim = -1) def ema_inplace(moving_avg, new, decay): moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) def sample_vectors(samples, num): num_samples, device = samples.shape[0], samples.device if num_samples >= num: indices = torch.randperm(num_samples, device = device)[:num] else: indices = torch.randint(0, num_samples, (num,), device = device) return samples[indices] def kmeans(samples, num_clusters, num_iters = 10, use_cosine_sim = False): dim, dtype, device = samples.shape[-1], samples.dtype, samples.device means = sample_vectors(samples, num_clusters) for _ in range(num_iters): if use_cosine_sim: dists = samples @ means.t() else: diffs = rearrange(samples, 'n d -> n () d') \ - rearrange(means, 'c d -> () c d') dists = -(diffs ** 2).sum(dim = -1) buckets = dists.max(dim = -1).indices bins = torch.bincount(buckets, minlength = num_clusters) zero_mask = bins == 0 bins_min_clamped = bins.masked_fill(zero_mask, 1) new_means = buckets.new_zeros(num_clusters, dim, dtype = dtype) new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d = dim), samples) new_means = new_means / bins_min_clamped[..., None] if use_cosine_sim: new_means = l2norm(new_means) means = torch.where(zero_mask[..., None], means, new_means) return means, bins class EmbeddingEMA(nn.Module): def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''): super().__init__() self.num_tokens = num_tokens self.codebook_dim = codebook_dim self.decay = decay self.eps = eps if codebook_init_path == '': if not kmeans_init: weight = torch.randn(num_tokens, codebook_dim) weight = l2norm(weight) else: weight = torch.zeros(num_tokens, codebook_dim) self.register_buffer('initted', torch.Tensor([not kmeans_init])) else: print(f"load init codebook weight from {codebook_init_path}") codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu') weight = codebook_ckpt_weight.clone() self.register_buffer('initted', torch.Tensor([True])) self.weight = nn.Parameter(weight, requires_grad = False) self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) # self.register_buffer('initted', torch.Tensor([not kmeans_init])) self.update = True @torch.jit.ignore def init_embed_(self, data): if self.initted: return print("Performing Kemans init for codebook") embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim = True) self.weight.data.copy_(embed) self.cluster_size.data.copy_(cluster_size) self.initted.data.copy_(torch.Tensor([True])) def forward(self, embed_id): return F.embedding(embed_id, self.weight) def cluster_size_ema_update(self, new_cluster_size): self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) def embed_avg_ema_update(self, new_embed_avg): self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) def weight_update(self, num_tokens): n = self.cluster_size.sum() smoothed_cluster_size = ( (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n ) #normalize embedding average with smoothed cluster size embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1)) self.weight.data.copy_(embed_normalized) def norm_ema_inplace(moving_avg, new, decay): moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) moving_avg.data.copy_(l2norm(moving_avg.data)) class NormEMAVectorQuantizer(nn.Module): def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, statistic_code_usage=True, kmeans_init=False, codebook_init_path=''): super().__init__() self.codebook_dim = embedding_dim self.num_tokens = n_embed self.beta = beta self.decay = decay # learnable = True if orthogonal_reg_weight > 0 else False self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path) self.statistic_code_usage = statistic_code_usage if statistic_code_usage: self.register_buffer('cluster_size', torch.zeros(n_embed)) if distributed.is_available() and distributed.is_initialized(): print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!") self.all_reduce_fn = distributed.all_reduce else: self.all_reduce_fn = nn.Identity() def reset_cluster_size(self, device): if self.statistic_code_usage: self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) self.cluster_size = self.cluster_size.to(device) def forward(self, z): # reshape z -> (batch, height, width, channel) and flatten #z, 'b c h w -> b h w c' z = rearrange(z, 'b c h w -> b h w c') z = l2norm(z) z_flattened = z.reshape(-1, self.codebook_dim) self.embedding.init_embed_(z_flattened) d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ self.embedding.weight.pow(2).sum(dim=1) - 2 * \ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' encoding_indices = torch.argmin(d, dim=1) z_q = self.embedding(encoding_indices).view(z.shape) encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) if not self.training: with torch.no_grad(): cluster_size = encodings.sum(0) self.all_reduce_fn(cluster_size) ema_inplace(self.cluster_size, cluster_size, self.decay) if self.training and self.embedding.update: #EMA cluster size bins = encodings.sum(0) self.all_reduce_fn(bins) # self.embedding.cluster_size_ema_update(bins) ema_inplace(self.cluster_size, bins, self.decay) zero_mask = (bins == 0) bins = bins.masked_fill(zero_mask, 1.) embed_sum = z_flattened.t() @ encodings self.all_reduce_fn(embed_sum) embed_normalized = (embed_sum / bins.unsqueeze(0)).t() embed_normalized = l2norm(embed_normalized) embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight, embed_normalized) norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay) # compute loss for embedding loss = self.beta * F.mse_loss(z_q.detach(), z) # preserve gradients z_q = z + (z_q - z).detach() # reshape back to match original input shape #z_q, 'b h w c -> b c h w' z_q = rearrange(z_q, 'b h w c -> b c h w') return z_q, loss, encoding_indices