# MoMask-test/models/t2m_eval_modules.py
# Hugging Face file page: uploaded by andrewatef via huggingface_hub,
# commit 823807d (verified), 6.5 kB, malware scan: no virus.
import torch
import torch.nn as nn
import numpy as np
import time
import math
import random
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# from networks.layers import *
def init_weight(m):
    """Re-initialise conv / transposed-conv / linear layers in place.

    Meant to be passed to ``module.apply``. Weights of ``nn.Conv1d``,
    ``nn.Linear`` and ``nn.ConvTranspose1d`` get Xavier-normal values and
    their bias (when present) is set to zero. Every other module type is
    left untouched.
    """
    if not isinstance(m, (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)):
        return
    nn.init.xavier_normal_(m.weight)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
# batch_size, dimension and position
# output: (batch_size, dim)
def positional_encoding(batch_size, dim, pos):
assert batch_size == pos.shape[0]
positions_enc = np.array([
[pos[j] / np.power(10000, (i-i%2)/dim) for i in range(dim)]
for j in range(batch_size)
], dtype=np.float32)
positions_enc[:, 0::2] = np.sin(positions_enc[:, 0::2])
positions_enc[:, 1::2] = np.cos(positions_enc[:, 1::2])
return torch.from_numpy(positions_enc).float()
def get_padding_mask(batch_size, seq_len, cap_lens):
    """Build padding masks from per-sample lengths.

    Args:
        batch_size: number of samples.
        seq_len: padded sequence length.
        cap_lens: tensor holding one valid length per sample.

    Returns:
        ``(mask, keep)``:
        - ``mask`` is a bool tensor of shape (batch_size, seq_len, seq_len).
          For sample ``i`` it is ``False`` in the first ``cap_lens[i]`` columns
          and ``True`` elsewhere.
        - ``keep`` is a float32 tensor of shape (batch_size, seq_len) equal to
          ``1 - mask[:, :, 0]``. That is all ones for a sample with a positive
          length and all zeros for a zero length.
    """
    lengths = cap_lens.data.tolist()
    padded = torch.ones((batch_size, seq_len, seq_len), dtype=torch.float32)
    for sample, length in enumerate(lengths):
        padded[sample, :, :length] = 0
    first_column = padded[:, :, 0].clone()
    return padded.bool(), 1 - first_column
def top_k_logits(logits, k):
    """Keep the ``k`` largest logits along the last dimension; mask the rest.

    Args:
        logits: tensor of any rank >= 1. Top-k is taken over the last
            dimension, as ``torch.topk`` does by default.
        k: number of entries to keep per row (``1 <= k <= logits.shape[-1]``).

    Returns:
        A copy of ``logits`` in which every entry smaller than the k-th largest
        value of its row is set to ``-inf``. Entries tied with the k-th value
        are kept, so a row may retain more than ``k`` finite values. The input
        tensor is not modified.
    """
    values, _ = torch.topk(logits, k)
    # ``...`` picks the k-th value of every row for any rank and keeps a
    # trailing size-1 axis so it broadcasts against ``logits``. Indexing axis
    # 1 explicitly would only be correct for 2-D input.
    kth = values[..., -1:]
    out = logits.clone()
    out[out < kth] = -float('Inf')
    return out
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table looked up by integer position.

    ``pe[p, 2j] = sin(p * w_j)`` and ``pe[p, 2j + 1] = cos(p * w_j)`` with
    ``w_j = 10000 ** (-2j / d_model)`` for positions ``0 .. max_len - 1``.
    Odd ``d_model`` is supported: the last column is then a sine with no
    matching cosine.
    """

    def __init__(self, d_model, max_len=300):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than even column, so
        # only the first d_model // 2 frequencies feed the cosines.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer, not Parameter: saved in state_dict and moved by .to(), never trained.
        self.register_buffer('pe', pe)

    def forward(self, pos):
        """Return the table rows selected by ``pos`` (an index or index tensor)."""
        return self.pe[pos]
class MovementConvEncoder(nn.Module):
    """Temporal convolutional encoder that shortens a sequence about 4x.

    Each ``Conv1d(kernel=4, stride=2, padding=1)`` maps a length ``T`` to
    ``T // 2``, so the output length is ``(T // 2) // 2``. A per-frame
    ``Linear`` then projects the features to ``output_size``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        layers = [
            nn.Conv1d(input_size, hidden_size, 4, 2, 1),
            nn.Dropout(0.2, inplace=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(hidden_size, output_size, 4, 2, 1),
            nn.Dropout(0.2, inplace=True),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Positions inside ``main`` become state_dict keys (main.0, main.3);
        # keep this order so saved checkpoints still load.
        self.main = nn.Sequential(*layers)
        self.out_net = nn.Linear(output_size, output_size)
        for module in (self.main, self.out_net):
            module.apply(init_weight)

    def forward(self, inputs):
        """Map (batch, time, input_size) to (batch, time', output_size)."""
        # Conv1d expects channels before time.
        channels_first = inputs.permute(0, 2, 1)
        features = self.main(channels_first)
        frames_first = features.permute(0, 2, 1)
        return self.out_net(frames_first)
class MovementConvDecoder(nn.Module):
    """Temporal transposed-convolution decoder that lengthens a sequence 4x.

    Each ``ConvTranspose1d(kernel=4, stride=2, padding=1)`` maps a length
    ``T`` to ``2 * T``, so the output is four times as long as the input. A
    per-frame ``Linear`` then projects the features to ``output_size``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        layers = [
            nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Positions inside ``main`` become state_dict keys (main.0, main.2);
        # keep this order so saved checkpoints still load.
        self.main = nn.Sequential(*layers)
        self.out_net = nn.Linear(output_size, output_size)
        for module in (self.main, self.out_net):
            module.apply(init_weight)

    def forward(self, inputs):
        """Map (batch, time, input_size) to (batch, 4 * time, output_size)."""
        # ConvTranspose1d expects channels before time.
        channels_first = inputs.permute(0, 2, 1)
        features = self.main(channels_first)
        frames_first = features.permute(0, 2, 1)
        return self.out_net(frames_first)
class TextEncoderBiGRUCo(nn.Module):
    """Bidirectional-GRU text encoder producing one vector per caption.

    The part-of-speech features are linearly projected and added to the word
    embeddings. The sum is projected to ``hidden_size`` and run through a
    single-layer bidirectional GRU. The forward and backward final hidden
    states are concatenated and mapped to ``output_size`` by a small MLP.

    Attribute names (``pos_emb``, ``input_emb``, ``gru``, ``output_net``,
    ``hidden``) are the ``state_dict`` keys; keep them stable so existing
    checkpoints still load.
    """

    def __init__(self, word_size, pos_size, hidden_size, output_size, device):
        super(TextEncoderBiGRUCo, self).__init__()
        # Stored for callers; not used inside this module.
        self.device = device

        self.pos_emb = nn.Linear(pos_size, word_size)
        self.input_emb = nn.Linear(word_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.output_net = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, output_size),
        )

        self.input_emb.apply(init_weight)
        self.pos_emb.apply(init_weight)
        self.output_net.apply(init_weight)
        self.hidden_size = hidden_size
        # Learned initial GRU state, one vector per direction: (2, 1, hidden_size).
        self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))

    def forward(self, word_embs, pos_onehot, cap_lens):
        """Encode a padded, batch-first batch of captions.

        Args:
            word_embs: (batch, seq_len, word_size) word vectors.
            pos_onehot: (batch, seq_len, pos_size) part-of-speech features.
                Presumably one-hot, per the name; TODO confirm.
            cap_lens: 1-D tensor with the valid length of each caption. The
                lengths do not need to be sorted.

        Returns:
            (batch, output_size) tensor, in the same order as the input batch.
        """
        num_samples = word_embs.shape[0]

        pos_embs = self.pos_emb(pos_onehot)
        inputs = word_embs + pos_embs
        input_embs = self.input_emb(inputs)
        # Broadcast the learned initial state over the batch: (2, batch, hidden).
        hidden = self.hidden.repeat(1, num_samples, 1)

        # A plain Python list keeps the lengths on the CPU, as packing requires.
        lengths = cap_lens.data.tolist()
        # enforce_sorted=False lets callers pass batches in any order. The GRU
        # permutes h_n back to the input order, and already-sorted batches give
        # the same result as before.
        emb = pack_padded_sequence(input_embs, lengths, batch_first=True, enforce_sorted=False)
        _, gru_last = self.gru(emb, hidden)
        # gru_last is (2, batch, hidden): join the forward and backward final states.
        gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
        return self.output_net(gru_last)
class MotionEncoderBiGRUCo(nn.Module):
    """Bidirectional-GRU motion encoder producing one vector per sequence.

    Per-frame features are projected to ``hidden_size`` and run through a
    single-layer bidirectional GRU. The forward and backward final hidden
    states are concatenated and mapped to ``output_size`` by a small MLP.

    Attribute names (``input_emb``, ``gru``, ``output_net``, ``hidden``) are
    the ``state_dict`` keys; keep them stable so existing checkpoints still
    load.
    """

    def __init__(self, input_size, hidden_size, output_size, device):
        super(MotionEncoderBiGRUCo, self).__init__()
        # Stored for callers; not used inside this module.
        self.device = device

        self.input_emb = nn.Linear(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.output_net = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, output_size),
        )

        self.input_emb.apply(init_weight)
        self.output_net.apply(init_weight)
        self.hidden_size = hidden_size
        # Learned initial GRU state, one vector per direction: (2, 1, hidden_size).
        self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))

    def forward(self, inputs, m_lens):
        """Encode a padded, batch-first batch of feature sequences.

        Args:
            inputs: (batch, seq_len, input_size) per-frame features.
            m_lens: 1-D tensor with the valid length of each sequence. The
                lengths do not need to be sorted.

        Returns:
            (batch, output_size) tensor, in the same order as the input batch.
        """
        num_samples = inputs.shape[0]

        input_embs = self.input_emb(inputs)
        # Broadcast the learned initial state over the batch: (2, batch, hidden).
        hidden = self.hidden.repeat(1, num_samples, 1)

        # A plain Python list keeps the lengths on the CPU, as packing requires.
        lengths = m_lens.data.tolist()
        # enforce_sorted=False lets callers pass batches in any order. The GRU
        # permutes h_n back to the input order, and already-sorted batches give
        # the same result as before.
        emb = pack_padded_sequence(input_embs, lengths, batch_first=True, enforce_sorted=False)
        _, gru_last = self.gru(emb, hidden)
        # gru_last is (2, batch, hidden): join the forward and backward final states.
        gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
        return self.output_net(gru_last)