kevinwang676's picture
Duplicate from kevinwang676/ChatGLM2-SadTalker
9074b85
raw
history blame
6.07 kB
import torch
import torch.nn.functional as F
from torch import nn
from src.audio2pose_models.res_unet import ResUnet
def class2onehot(idx, class_num):
assert torch.max(idx).item() < class_num
onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
onehot.scatter_(1, idx, 1)
return onehot
class CVAE(nn.Module):
def __init__(self, cfg):
super().__init__()
encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
latent_size = cfg.MODEL.CVAE.LATENT_SIZE
num_classes = cfg.DATASET.NUM_CLASSES
audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
seq_len = cfg.MODEL.CVAE.SEQ_LEN
self.latent_size = latent_size
self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
audio_emb_in_size, audio_emb_out_size, seq_len)
self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
audio_emb_in_size, audio_emb_out_size, seq_len)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, batch):
batch = self.encoder(batch)
mu = batch['mu']
logvar = batch['logvar']
z = self.reparameterize(mu, logvar)
batch['z'] = z
return self.decoder(batch)
def test(self, batch):
'''
class_id = batch['class']
z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
batch['z'] = z
'''
return self.decoder(batch)
class ENCODER(nn.Module):
def __init__(self, layer_sizes, latent_size, num_classes,
audio_emb_in_size, audio_emb_out_size, seq_len):
super().__init__()
self.resunet = ResUnet()
self.num_classes = num_classes
self.seq_len = seq_len
self.MLP = nn.Sequential()
layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
self.MLP.add_module(
name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
def forward(self, batch):
class_id = batch['class']
pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
ref = batch['ref'] #bs 6
bs = pose_motion_gt.shape[0]
audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
#pose encode
pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
#audio mapping
print(audio_in.shape)
audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
audio_out = audio_out.reshape(bs, -1)
class_bias = self.classbias[class_id] #bs latent_size
x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
x_out = self.MLP(x_in)
mu = self.linear_means(x_out)
logvar = self.linear_means(x_out) #bs latent_size
batch.update({'mu':mu, 'logvar':logvar})
return batch
class DECODER(nn.Module):
def __init__(self, layer_sizes, latent_size, num_classes,
audio_emb_in_size, audio_emb_out_size, seq_len):
super().__init__()
self.resunet = ResUnet()
self.num_classes = num_classes
self.seq_len = seq_len
self.MLP = nn.Sequential()
input_size = latent_size + seq_len*audio_emb_out_size + 6
for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
self.MLP.add_module(
name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
if i+1 < len(layer_sizes):
self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
else:
self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
self.pose_linear = nn.Linear(6, 6)
self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
def forward(self, batch):
z = batch['z'] #bs latent_size
bs = z.shape[0]
class_id = batch['class']
ref = batch['ref'] #bs 6
audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
#print('audio_in: ', audio_in[:, :, :10])
audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
#print('audio_out: ', audio_out[:, :, :10])
audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
class_bias = self.classbias[class_id] #bs latent_size
z = z + class_bias
x_in = torch.cat([ref, z, audio_out], dim=-1)
x_out = self.MLP(x_in) # bs layer_sizes[-1]
x_out = x_out.reshape((bs, self.seq_len, -1))
#print('x_out: ', x_out)
pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
batch.update({'pose_motion_pred':pose_motion_pred})
return batch