Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,324 Bytes
445d3d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# This code is based on https://github.com/Mael-zys/T2M-GPT.git
import torch.nn as nn
from models.encdec import Encoder, Decoder
from models.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset
class VQVAE_251(nn.Module):
    """Convolutional VQ-VAE over sequences of per-frame feature vectors.

    Inputs are ``(batch, T, C_in)`` tensors. ``C_in`` is the feature width
    handed to ``Encoder``/``Decoder``: 251 when ``args.dataname == 'kit'``,
    263 otherwise. The encoder output is quantized with the codebook selected
    by ``args.quantizer`` and then decoded back to ``C_in`` features per frame.
    """

    def __init__(self,
                 args,
                 nb_code=1024,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        """Build the encoder, decoder and quantizer.

        Args:
            args: Config namespace. Reads ``args.quantizer`` (one of
                ``"ema_reset"``, ``"orig"``, ``"ema"``, ``"reset"``) and
                ``args.dataname``. It is also passed through to the
                ``ema_reset``/``ema``/``reset`` quantizer constructors.
            nb_code: Number of codebook entries.
            code_dim: Size of each codebook entry.
            output_emb_width, down_t, stride_t, width, depth,
            dilation_growth_rate, activation, norm: Forwarded unchanged to
                both ``Encoder`` and ``Decoder``.

        Raises:
            ValueError: If ``args.quantizer`` is not a supported name.
        """
        super().__init__()
        # Validate before building anything. An unsupported name would
        # otherwise leave ``self.quantizer`` unset and only surface later as an
        # AttributeError inside encode/forward.
        supported = ("ema_reset", "orig", "ema", "reset")
        if args.quantizer not in supported:
            raise ValueError(
                f"Unsupported quantizer {args.quantizer!r}; "
                f"expected one of {supported}")

        self.code_dim = code_dim
        self.num_code = nb_code
        self.quant = args.quantizer

        # Per-frame feature width of the raw input, shared by encoder and decoder.
        input_width = 251 if args.dataname == 'kit' else 263
        self.encoder = Encoder(input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm)
        self.decoder = Decoder(input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm)

        if args.quantizer == "ema_reset":
            self.quantizer = QuantizeEMAReset(nb_code, code_dim, args)
        elif args.quantizer == "orig":
            # The third argument (1.0) is interpreted by ``Quantizer``;
            # presumably a loss weight -- confirm against its definition.
            self.quantizer = Quantizer(nb_code, code_dim, 1.0)
        elif args.quantizer == "ema":
            self.quantizer = QuantizeEMA(nb_code, code_dim, args)
        else:  # "reset" -- the only remaining option after the check above
            self.quantizer = QuantizeReset(nb_code, code_dim, args)

    def preprocess(self, x):
        """Move features to the channel axis and cast to float32.

        ``(bs, T, C) -> (bs, C, T)``. This is the channels-first layout fed to
        the encoder.
        """
        x = x.permute(0, 2, 1).float()
        return x

    def postprocess(self, x):
        """Undo :meth:`preprocess`'s transpose: ``(bs, C, T) -> (bs, T, C)``."""
        x = x.permute(0, 2, 1)
        return x

    def encode_x(self, x):
        """Return the continuous (pre-quantization) encoder features.

        Args:
            x: Tensor of shape ``(N, T, C_in)``.

        Returns:
            Tensor of shape ``(N * T', C)``: every time step of every sequence
            flattened into one row. ``T'`` and ``C`` are the encoder's output
            length and channel count (presumably ``T`` downsampled by the
            encoder and ``output_emb_width`` -- confirm against ``Encoder``).

        Raises:
            ValueError: If ``x`` is not 3-D.
        """
        if len(x.shape) != 3:
            raise ValueError(
                f"expected a 3-D (batch, time, features) tensor, got shape {tuple(x.shape)}")
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        x_encoder = self.postprocess(x_encoder)
        # ``view`` needs contiguous memory; the permute above breaks contiguity.
        return x_encoder.contiguous().view(-1, x_encoder.shape[-1])

    def encode(self, x):
        """Encode ``x`` to codebook indices.

        Args:
            x: Tensor of shape ``(N, T, C_in)``.

        Returns:
            Index tensor of shape ``(N, T')``. It is produced by
            ``self.quantizer.quantize`` applied to :meth:`encode_x`'s
            flattened features.
        """
        x_flat = self.encode_x(x)  # also validates that x is 3-D
        code_idx = self.quantizer.quantize(x_flat)
        return code_idx.view(x.shape[0], -1)

    def forward(self, x):
        """Run encoder -> quantizer -> decoder.

        Args:
            x: Tensor of shape ``(bs, T, C_in)``.

        Returns:
            ``(x_out, loss, perplexity)``. ``x_out`` is the reconstruction in
            ``(bs, T_out, C_in)`` layout. ``loss`` and ``perplexity`` are
            returned as-is from ``self.quantizer``.
        """
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        x_quantized, loss, perplexity = self.quantizer(x_encoder)
        x_decoder = self.decoder(x_quantized)
        x_out = self.postprocess(x_decoder)
        return x_out, loss, perplexity

    def forward_decoder(self, x):
        """Decode codebook indices back to features.

        All given indices are decoded as ONE sequence (batch size 1): the
        dequantized vectors are reshaped to ``(1, -1, code_dim)`` before
        decoding. Batched index tensors are therefore concatenated along time.

        NOTE(review): the decoder receives ``code_dim`` channels here, so this
        assumes ``code_dim`` matches the width the decoder expects
        (``output_emb_width``) -- confirm.

        Args:
            x: Codebook indices accepted by ``self.quantizer.dequantize``.

        Returns:
            Tensor of shape ``(1, T_out, C_in)``.
        """
        x_d = self.quantizer.dequantize(x)
        x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous()
        x_decoder = self.decoder(x_d)
        x_out = self.postprocess(x_decoder)
        return x_out
class HumanVQVAE(nn.Module):
    """Thin wrapper that delegates every call to an inner ``VQVAE_251``.

    Note: this wrapper's default codebook size (``nb_code=512``) differs from
    ``VQVAE_251``'s own default (1024). The value given here is forwarded and
    wins.
    """

    def __init__(self,
                 args,
                 nb_code=512,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        """Build the inner ``VQVAE_251``; all arguments are forwarded unchanged.

        Args:
            args: Config namespace. ``args.dataname`` is read here; the whole
                namespace is passed on to ``VQVAE_251``.
            nb_code, code_dim, output_emb_width, down_t, stride_t, width,
            depth, dilation_growth_rate, activation, norm: Forwarded to
                ``VQVAE_251``.
        """
        super().__init__()
        # Joint count for the dataset (21 for 'kit', else 22). It is not used
        # inside this class; presumably read by external callers -- verify.
        self.nb_joints = 21 if args.dataname == 'kit' else 22
        self.vqvae = VQVAE_251(args, nb_code, code_dim, output_emb_width, down_t, stride_t,
                               width, depth, dilation_growth_rate,
                               activation=activation, norm=norm)

    def encode(self, x):
        """Encode ``(N, T, C_in)`` input to ``(N, T')`` codebook indices."""
        return self.vqvae.encode(x)

    def encode_x(self, x):
        """Return continuous encoder features, flattened to ``(N * T', C)``.

        These are features, not indices; shape as produced by
        ``VQVAE_251.encode_x``.
        """
        return self.vqvae.encode_x(x)

    def forward(self, x):
        """Run the full VQ-VAE; returns ``(x_out, loss, perplexity)``."""
        x_out, loss, perplexity = self.vqvae(x)
        return x_out, loss, perplexity

    def forward_decoder(self, x):
        """Decode codebook indices to features via ``VQVAE_251.forward_decoder``."""
        return self.vqvae.forward_decoder(x)
|