Spaces:
Running
Running
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange, repeat, reduce, pack, unpack | |
# from vector_quantize_pytorch import ResidualVQ | |
# Borrowed from vector_quantize_pytorch
def log(t, eps = 1e-20):
    """Return the elementwise natural log of ``t``, flooring inputs at ``eps``.

    Entries of ``t`` smaller than ``eps`` are raised to ``eps`` before the
    logarithm is taken, so zeros and negatives map to ``log(eps)`` instead of
    ``-inf`` / ``nan``.
    """
    floored = torch.clamp(t, min=eps)
    return floored.log()
def gumbel_noise(t):
    """Draw Gumbel noise shaped like ``t``.

    Uniform samples ``u`` drawn on the unit interval are mapped through
    ``-log(-log(u))``. The file's clamped ``log`` helper is used for both
    logarithms, which keeps the result finite at the interval's edges.
    """
    uniform = torch.empty_like(t).uniform_(0, 1)
    return -log(-log(uniform))
def gumbel_sample(
    logits,
    temperature = 1.,
    stochastic = False,
    dim = -1,
    training = True
):
    """Pick an index along ``dim``, optionally via Gumbel-max sampling.

    Gumbel noise is added to the temperature-scaled logits only when
    ``training`` and ``stochastic`` are both truthy and ``temperature`` is
    positive; in every other case the plain argmax of ``logits`` is taken.

    Returns:
        Tensor of selected indices (``logits`` with ``dim`` reduced away).
    """
    use_noise = training and stochastic and temperature > 0
    if not use_noise:
        # Deterministic path: no scaling, no noise.
        return logits.argmax(dim = dim)

    perturbed = (logits / temperature) + gumbel_noise(logits)
    return perturbed.argmax(dim = dim)
class QuantizeEMAReset(nn.Module):
    """Vector quantizer with an EMA-maintained codebook and code re-seeding.

    The codebook is a registered buffer of shape ``(nb_code, code_dim)``. It is
    never optimised by gradient descent: during training it is rebuilt each
    step from exponential moving averages of the per-code sums and counts of
    the vectors assigned to each code. Codes whose running count falls below
    1.0 are overwritten with rows taken from the current batch.

    Args:
        nb_code: Number of codebook entries.
        code_dim: Dimensionality of each entry.
        args: Any object exposing a ``mu`` attribute, used as the EMA decay.
    """

    def __init__(self, nb_code, code_dim, args):
        super(QuantizeEMAReset, self).__init__()
        self.nb_code = nb_code
        self.code_dim = code_dim
        self.mu = args.mu  # EMA decay factor (marked TO_DO in the original)
        self.reset_codebook()

    def reset_codebook(self):
        """Zero the codebook buffer and drop the running EMA statistics."""
        self.init = False
        self.code_sum = None
        self.code_count = None
        self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, requires_grad=False))

    def _tile(self, x):
        """Return ``x`` repeated (with small noise) to at least ``nb_code`` rows.

        ``x`` is 2-D, ``(rows, code_dim)``. When it already has ``nb_code`` or
        more rows it is returned unchanged.
        """
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.nb_code:
            # Ceil-divide so the repeated tensor covers every code.
            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            # Jitter the copies so duplicated rows are not exactly identical.
            out = out + torch.randn_like(out) * std
        else:
            out = x
        return out

    @torch.no_grad()
    def init_codebook(self, x):
        """Seed the codebook and EMA statistics from the rows of ``x``.

        Runs without autograd and stores a detached copy, so the codebook
        buffer never becomes part of (or aliases) the encoder's graph.
        """
        out = self._tile(x)
        self.codebook = out[:self.nb_code].detach().clone()
        self.code_sum = self.codebook.clone()
        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
        self.init = True

    def quantize(self, x, sample_codebook_temp=0.):
        """Return, for each row of ``x``, the index of the selected code.

        ``x`` is ``(rows, code_dim)``. Selection goes through ``gumbel_sample``
        on the negated squared distances; with the default temperature of 0
        that reduces to the nearest code.
        """
        # (nb_code, code_dim) -> (code_dim, nb_code)
        k_w = self.codebook.t()
        # Squared Euclidean distance expanded as |x|^2 - 2 x.k + |k|^2,
        # giving one row per input vector and one column per code.
        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - \
                   2 * torch.matmul(x, k_w) + \
                   torch.sum(k_w ** 2, dim=0, keepdim=True)
        code_idx = gumbel_sample(-distance, dim = -1, temperature = sample_codebook_temp, stochastic=True, training = self.training)
        return code_idx

    def dequantize(self, code_idx):
        """Look up the codebook vectors for ``code_idx``."""
        x = F.embedding(code_idx, self.codebook)
        return x

    def get_codebook_entry(self, indices):
        """Return codebook vectors for ``indices`` with the last two axes swapped.

        The ``permute(0, 2, 1)`` requires a 3-D lookup result, i.e. 2-D
        ``indices`` (presumably ``(N, T)``, yielding ``(N, code_dim, T)``).
        """
        return self.dequantize(indices).permute(0, 2, 1)

    @torch.no_grad()
    def compute_perplexity(self, code_idx):
        """Return the perplexity of the code usage in the 1-D ``code_idx``."""
        # One-hot assignment matrix: (nb_code, number of vectors).
        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)

        code_count = code_onehot.sum(dim=-1)  # (nb_code,)
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        """EMA-update the codebook from this batch and return usage perplexity.

        Args:
            x: ``(rows, code_dim)`` vectors that were quantized.
            code_idx: Code index assigned to each row of ``x``.

        Autograd is disabled here: the running statistics are carried across
        iterations, so tracking them would chain every step's graph onto the
        previous one and keep it alive.
        """
        # One-hot assignment matrix: (nb_code, rows).
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)

        code_sum = torch.matmul(code_onehot, x)  # (nb_code, code_dim)
        code_count = code_onehot.sum(dim=-1)  # (nb_code,)

        # Replacement vectors for under-used codes, taken from the batch.
        out = self._tile(x)
        code_rand = out[:self.nb_code]

        # Update centres
        self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
        self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count

        used = self.code_count.view(self.nb_code, 1) >= 1.0
        code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
        # Select rather than blend arithmetically: a code whose running count
        # has reached 0 produces 0/0 = NaN in code_update, and 0 * NaN would
        # otherwise poison the replacement row.
        self.codebook = torch.where(used, code_update, code_rand)

        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    def preprocess(self, x):
        """Flatten ``(N, C, T)`` into ``(N * T, C)``."""
        x = rearrange(x, 'n c t -> (n t) c')
        return x

    def forward(self, x, return_idx=False, temperature=0.):
        """Quantize ``x`` of shape ``(N, width, T)``.

        Returns:
            ``(x_d, commit_loss, perplexity)``, or
            ``(x_d, code_idx, commit_loss, perplexity)`` when ``return_idx`` is
            true. ``x_d`` is laid out as ``(N, C, T)`` and ``code_idx`` as
            ``(N, T)``.
        """
        N, width, T = x.shape

        x = self.preprocess(x)
        # Lazily seed the codebook from the first training batch.
        if self.training and not self.init:
            self.init_codebook(x)

        code_idx = self.quantize(x, temperature)
        x_d = self.dequantize(code_idx)

        if self.training:
            perplexity = self.update_codebook(x, code_idx)
        else:
            perplexity = self.compute_perplexity(code_idx)

        # Commitment loss only: the codebook side is detached because it is
        # updated by EMA, not by gradients (the original author notes that the
        # T2M-GPT paper states the embed/commitment losses incorrectly).
        commit_loss = F.mse_loss(x, x_d.detach())

        # Straight-through estimator: forward value is x_d, gradient goes to x.
        x_d = x + (x_d - x).detach()

        # Postprocess: (N * T, C) -> (N, C, T)
        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()
        code_idx = code_idx.view(N, T).contiguous()
        if return_idx:
            return x_d, code_idx, commit_loss, perplexity
        return x_d, commit_loss, perplexity
class QuantizeEMA(QuantizeEMAReset):
    """EMA codebook quantizer that keeps, rather than re-seeds, under-used codes.

    Only ``update_codebook`` differs from the parent class: a code whose
    running count is below 1.0 retains its current vector instead of being
    replaced with a row from the batch.
    """

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        """EMA-update the codebook from this batch and return usage perplexity.

        Args:
            x: ``(rows, code_dim)`` vectors that were quantized.
            code_idx: Code index assigned to each row of ``x``.

        Autograd is disabled here: the running statistics are carried across
        iterations, so tracking them would chain every step's graph onto the
        previous one and keep it alive.
        """
        # One-hot assignment matrix: (nb_code, rows).
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device)
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)

        code_sum = torch.matmul(code_onehot, x)  # (nb_code, code_dim)
        code_count = code_onehot.sum(dim=-1)  # (nb_code,)

        # Update centres
        self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
        self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count

        used = self.code_count.view(self.nb_code, 1) >= 1.0
        code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
        # Select rather than blend arithmetically: a code whose running count
        # has reached 0 produces 0/0 = NaN in code_update, and 0 * NaN would
        # otherwise overwrite the vector that is meant to be kept.
        self.codebook = torch.where(used, code_update, self.codebook)

        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity