import torch.nn as nn

from models.vq.encdec import Encoder, Decoder
from models.vq.residual_vq import ResidualVQ


class RVQVAE(nn.Module):
    """Residual VQ-VAE: a temporal encoder/decoder pair with a multi-layer
    residual vector quantizer in the bottleneck."""

    def __init__(self,
                 args,
                 input_width=263,
                 nb_code=1024,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        super().__init__()
        assert output_emb_width == code_dim
        self.code_dim = code_dim
        self.num_code = nb_code

        self.encoder = Encoder(input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm)
        self.decoder = Decoder(input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm)

        rvqvae_config = {
            'num_quantizers': args.num_quantizers,
            'shared_codebook': args.shared_codebook,
            'quantize_dropout_prob': args.quantize_dropout_prob,
            'quantize_dropout_cutoff_index': 0,
            'nb_code': nb_code,
            'code_dim': code_dim,
            'args': args,
        }
        self.quantizer = ResidualVQ(**rvqvae_config)

    def preprocess(self, x):
        # (bs, T, Jx3) -> (bs, Jx3, T)
        return x.permute(0, 2, 1).float()

    def postprocess(self, x):
        # (bs, Jx3, T) -> (bs, T, Jx3)
        return x.permute(0, 2, 1)

    def encode(self, x):
        """Map a motion sequence to discrete code indices.

        Returns `code_idx` of shape (N, T, Q) together with the per-layer
        quantized latents `all_codes`.
        """
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        code_idx, all_codes = self.quantizer.quantize(x_encoder, return_latent=True)
        return code_idx, all_codes

    def forward(self, x):
        x_in = self.preprocess(x)
        # Encode
        x_encoder = self.encoder(x_in)
        # Residual quantization; the codebook sampling temperature is
        # hard-coded (the original code carried a TODO about this).
        x_quantized, code_idx, commit_loss, perplexity = self.quantizer(
            x_encoder, sample_codebook_temp=0.5)
        # Decode back to the motion feature space
        x_out = self.decoder(x_quantized)
        return x_out, commit_loss, perplexity

    def forward_decoder(self, x):
        """Decode from code indices: look up each quantizer layer's codes,
        sum the residual layers, and run the decoder."""
        x_d = self.quantizer.get_codes_from_indices(x)
        # (num_quantizers, N, T, code_dim) -> (N, code_dim, T)
        x = x_d.sum(dim=0).permute(0, 2, 1)
        return self.decoder(x)


class LengthEstimator(nn.Module):
    """MLP head that predicts a motion-length distribution from a text embedding."""

    def __init__(self, input_size, output_size):
        super().__init__()
        nd = 512
        self.output = nn.Sequential(
            nn.Linear(input_size, nd),
            nn.LayerNorm(nd),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Dropout(0.2),
            nn.Linear(nd, nd // 2),
            nn.LayerNorm(nd // 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Dropout(0.2),
            nn.Linear(nd // 2, nd // 4),
            nn.LayerNorm(nd // 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(nd // 4, output_size)
        )
        self.output.apply(self.__init_weights)

    def __init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, text_emb):
        return self.output(text_emb)
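

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). A minimal
# smoke test assuming the Encoder/Decoder and ResidualVQ implementations in
# models.vq accept the arguments used above, and that ResidualVQ reads only
# the config fields set here. The `args` values below are made-up
# placeholders; take them from your training config in practice.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace
    import torch

    args = SimpleNamespace(
        num_quantizers=6,            # assumed value
        shared_codebook=False,       # assumed value
        quantize_dropout_prob=0.2,   # assumed value
    )
    model = RVQVAE(args)                 # defaults: 263-dim inputs, 1024 codes
    motion = torch.randn(2, 64, 263)     # (bs, T, feature_dim)

    recon, commit_loss, perplexity = model(motion)   # reconstruction pass
    code_idx, all_codes = model.encode(motion)       # discrete token indices
    print(recon.shape, code_idx.shape)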