import random
from math import ceil
from functools import partial
from itertools import zip_longest
from random import randrange
import torch
from torch import nn
import torch.nn.functional as F
# from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
from models.vq.quantizer import QuantizeEMAReset, QuantizeEMA
from einops import rearrange, repeat, pack, unpack
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def round_up_multiple(num, mult):
return ceil(num / mult) * mult
# main class
class ResidualVQ(nn.Module):
""" Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
def __init__(
self,
num_quantizers,
shared_codebook=False,
quantize_dropout_prob=0.5,
quantize_dropout_cutoff_index=0,
**kwargs
):
super().__init__()
self.num_quantizers = num_quantizers
# self.layers = nn.ModuleList([VectorQuantize(accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)])
if shared_codebook:
layer = QuantizeEMAReset(**kwargs)
self.layers = nn.ModuleList([layer for _ in range(num_quantizers)])
else:
self.layers = nn.ModuleList([QuantizeEMAReset(**kwargs) for _ in range(num_quantizers)])
# self.layers = nn.ModuleList([QuantizeEMA(**kwargs) for _ in range(num_quantizers)])
# self.quantize_dropout = quantize_dropout and num_quantizers > 1
assert quantize_dropout_cutoff_index >= 0 and quantize_dropout_prob >= 0
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
self.quantize_dropout_prob = quantize_dropout_prob
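        # note on configuration (descriptive only, no behavior change):
        # - quantize_dropout_prob is the per-forward-pass chance, during training, of truncating
        #   the quantizer stack at a random depth >= quantize_dropout_cutoff_index
        # - shared_codebook=True reuses a single QuantizeEMAReset instance for every layer,
        #   so all layers share the same codebook and EMA buffers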
@property
def codebooks(self):
codebooks = [layer.codebook for layer in self.layers]
codebooks = torch.stack(codebooks, dim = 0)
return codebooks # 'q c d'
def get_codes_from_indices(self, indices): #indices shape 'b n q' # dequantize
batch, quantize_dim = indices.shape[0], indices.shape[-1]
        # because of quantize dropout, the passed-in indices may only cover the first
        # (coarse) quantizers; pad the missing layers with -1 so they are treated as
        # dropped out and decode to zero vectors below
if quantize_dim < self.num_quantizers:
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)
# get ready for gathering
codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch)
gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1])
# take care of quantizer dropout
        mask = gather_indices == -1
gather_indices = gather_indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later
# print(gather_indices.max(), gather_indices.min())
all_codes = codebooks.gather(2, gather_indices) # gather all codes
# mask out any codes that were dropout-ed
all_codes = all_codes.masked_fill(mask, 0.)
return all_codes # 'q b n d'
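    # shape walk-through for get_codes_from_indices (illustrative), with q = num_quantizers,
    # c = codebook size, d = code dim:
    #   indices        (b, n, q')  -> right-padded with -1 up to q columns if q' < q
    #   codebooks      (q, c, d)   -> repeated to (q, b, c, d)
    #   gather_indices (q, b, n, d), -1 entries replaced by 0 then masked out after the gather
    #   all_codes      (q, b, n, d), dropped-out positions zeroed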
def get_codebook_entry(self, indices): #indices shape 'b n q'
all_codes = self.get_codes_from_indices(indices) #'q b n d'
latent = torch.sum(all_codes, dim=0) #'b n d'
latent = latent.permute(0, 2, 1)
return latent
def forward(self, x, return_all_codes = False, sample_codebook_temp = None, force_dropout_index=-1):
# debug check
# print(self.codebooks[:,0,0].detach().cpu().numpy())
num_quant, quant_dropout_prob, device = self.num_quantizers, self.quantize_dropout_prob, x.device
quantized_out = 0.
residual = x
all_losses = []
all_indices = []
all_perplexity = []
        should_quantize_dropout = self.training and random.random() < quant_dropout_prob
start_drop_quantize_index = num_quant
        # to ensure the first k layers learn as much as possible, randomly drop out the last (q - k) layers during training
        if should_quantize_dropout:
            # layers [0, quantize_dropout_cutoff_index] are always kept; TODO: vary the cutoff per sample in the batch
            start_drop_quantize_index = randrange(self.quantize_dropout_cutoff_index, num_quant)
            null_indices_shape = [x.shape[0], x.shape[-1]] # (b, n)
            null_indices = torch.full(null_indices_shape, -1, device = device, dtype = torch.long)
# null_loss = 0.
        if force_dropout_index >= 0:
            should_quantize_dropout = True
            start_drop_quantize_index = force_dropout_index
            null_indices_shape = [x.shape[0], x.shape[-1]] # (b, n)
            null_indices = torch.full(null_indices_shape, -1, device=device, dtype=torch.long)
# print(force_dropout_index)
# go through the layers
for quantizer_index, layer in enumerate(self.layers):
if should_quantize_dropout and quantizer_index > start_drop_quantize_index:
all_indices.append(null_indices)
# all_losses.append(null_loss)
continue
# layer_indices = None
# if return_loss:
# layer_indices = indices[..., quantizer_index] #gt indices
# quantized, *rest = layer(residual, indices = layer_indices, sample_codebook_temp = sample_codebook_temp) #single quantizer TODO
quantized, *rest = layer(residual, return_idx=True, temperature=sample_codebook_temp) #single quantizer
# print(quantized.shape, residual.shape)
            residual = residual - quantized.detach() # out-of-place, so the caller's x is not modified (matches quantize() below)
quantized_out += quantized
embed_indices, loss, perplexity = rest
all_indices.append(embed_indices)
all_losses.append(loss)
all_perplexity.append(perplexity)
# stack all losses and indices
all_indices = torch.stack(all_indices, dim=-1)
all_losses = sum(all_losses)/len(all_losses)
all_perplexity = sum(all_perplexity)/len(all_perplexity)
ret = (quantized_out, all_indices, all_losses, all_perplexity)
if return_all_codes:
# whether to return all codes from all codebooks across layers
all_codes = self.get_codes_from_indices(all_indices)
# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
ret = (*ret, all_codes)
return ret
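    # residual decomposition implemented by forward/quantize (illustrative), e.g. with 3 quantizers:
    #   q1 = Q1(x), q2 = Q2(x - q1), q3 = Q3(x - q1 - q2)
    #   quantized_out = q1 + q2 + q3 ~ x, and the per-layer indices are stacked as (b, n, 3)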
def quantize(self, x, return_latent=False):
all_indices = []
quantized_out = 0.
residual = x
all_codes = []
for quantizer_index, layer in enumerate(self.layers):
quantized, *rest = layer(residual, return_idx=True) #single quantizer
residual = residual - quantized.detach()
quantized_out = quantized_out + quantized
embed_indices, loss, perplexity = rest
all_indices.append(embed_indices)
# print(quantizer_index, embed_indices[0])
# print(quantizer_index, quantized[0])
# break
all_codes.append(quantized)
code_idx = torch.stack(all_indices, dim=-1)
all_codes = torch.stack(all_codes, dim=0)
if return_latent:
return code_idx, all_codes
        return code_idx
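# usage sketch (commented out, illustrative only): the extra keyword arguments are forwarded
# verbatim to QuantizeEMAReset in models/vq/quantizer.py, so the names below are assumptions
# about that constructor, not a spec of this module.
#
#   rvq = ResidualVQ(
#       num_quantizers=6,
#       shared_codebook=False,
#       quantize_dropout_prob=0.2,
#       quantize_dropout_cutoff_index=0,
#       **quantizer_kwargs,  # whatever QuantizeEMAReset expects (codebook size, code dim, ...)
#   )
#   x = torch.randn(4, code_dim, 49)  # (batch, code_dim, seq_len)
#   quantized, indices, commit_loss, perplexity = rvq(x)
#   # indices: (batch, seq_len, num_quantizers); decode back to the latent with:
#   latent = rvq.get_codebook_entry(indices)  # (batch, code_dim, seq_len)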