import torch
import torch.nn as nn
from transformers import PreTrainedModel
from configs import BertVAEConfig
from transformers.models.bert.modeling_bert import BertEncoder, BertModel

class BertVAE(PreTrainedModel):
    config_class = BertVAEConfig

    def __init__(self, config):
        super().__init__(config)
        # Trainable encoder/decoder stacks on top of a frozen pretrained BERT.
        self.encoder = BertEncoder(config)
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Project encoder hidden states to the mean and log-variance of the latent distribution.
        self.fc_mu = nn.Linear(config.hidden_size, config.hidden_size)
        self.fc_var = nn.Linear(config.hidden_size, config.hidden_size)
        # Classification heads over the [CLS] position on the encoder and decoder sides.
        self.enc_cls = nn.Linear(config.hidden_size, config.position_num)
        self.dec_cls = nn.Linear(config.hidden_size, config.position_num)
        self.decoder = BertEncoder(config)
        # Freeze the pretrained BERT; only the VAE components are trained.
        for p in self.bert.parameters():
            p.requires_grad = False
    def encode(self, input_ids, **kwargs):
        '''
        Encode input_ids into the parameters of the latent distribution.
        input_ids: (batch_size, seq_len)
        returns: mu, log_var, each of shape (batch_size, seq_len, hidden_size)
        '''
        x = self.bert(input_ids).last_hidden_state
        outputs = self.encoder(x, **kwargs)
        hidden_state = outputs.last_hidden_state
        mu = self.fc_mu(hidden_state)
        log_var = self.fc_var(hidden_state)
        return mu, log_var
    def encoder_cls(self, input_ids, **kwargs):
        '''
        Classify from the encoder's [CLS] (first-token) representation.
        input_ids: (batch_size, seq_len)
        '''
        x = self.bert(input_ids).last_hidden_state
        outputs = self.encoder(x, **kwargs)
        hidden_state = outputs.last_hidden_state
        return self.enc_cls(hidden_state[:, 0, :])
    def decoder_cls(self, z, **kwargs):
        '''
        Classify from the decoder's [CLS] (first-token) representation.
        z: latent tensor of shape (batch_size, seq_len, hidden_size)
        '''
        outputs = self.decoder(z, **kwargs)
        hidden_state = outputs.last_hidden_state
        return self.dec_cls(hidden_state[:, 0, :])
    def reparameterize(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    def decode(self, z, **kwargs):
        '''
        Decode a latent tensor back to hidden states.
        z: latent tensor of shape (batch_size, seq_len, hidden_size)
        '''
        outputs = self.decoder(z, **kwargs)
        return outputs.last_hidden_state
    def forward(self, input_ids, position=None, **kwargs):
        # encode expects the input_ids tensor directly, not an unpacked dict.
        mu, log_var = self.encode(input_ids, **kwargs)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, **kwargs), mu, log_var
    def _elbo(self, x, x_hat, mu, log_var):
        '''
        Compute the negative ELBO (the training loss).
        x: input hidden states of shape (batch_size, seq_len, hidden_size)
        x_hat: reconstructed hidden states of shape (batch_size, seq_len, hidden_size)
        mu: mean tensor of shape (batch_size, seq_len, hidden_size)
        log_var: log-variance tensor of shape (batch_size, seq_len, hidden_size)
        '''
        # Reconstruction term: MSE between BERT hidden states and their reconstruction.
        recon_loss = nn.functional.mse_loss(x_hat, x, reduction='mean')
        # KL divergence to the standard normal prior, summed over all elements.
        kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()))
        # Down-weight the KL term so the reconstruction term stays dominant.
        return recon_loss + 0.1 * kl_loss
    def elbo(self, input_ids, **kwargs):
        '''
        Encode, sample, decode, and return the negative ELBO for input_ids.
        input_ids: (batch_size, seq_len)
        '''
        x = self.bert(input_ids, **kwargs).last_hidden_state
        outputs = self.encoder(x, **kwargs)
        hidden_state = outputs.last_hidden_state
        mu = self.fc_mu(hidden_state)
        log_var = self.fc_var(hidden_state)
        z = self.reparameterize(mu, log_var)
        outputs = self.decoder(z, **kwargs)
        x_hat = outputs.last_hidden_state
        return self._elbo(x, x_hat, mu, log_var)
    def reconstruct(self, input_ids, **kwargs):
        '''
        Reconstruct the BERT hidden states for input_ids through the VAE.
        input_ids: (batch_size, seq_len)
        '''
        return self.forward(input_ids, **kwargs)[0]
    def sample(self, num_samples, device, **kwargs):
        '''
        Draw num_samples latent tensors from the standard normal prior and decode them.
        returns: hidden states of shape (num_samples, max_position_embeddings, hidden_size)
        '''
        z = torch.randn(num_samples, self.config.max_position_embeddings, self.config.hidden_size).to(device)
        return self.decode(z, **kwargs)
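

# Minimal smoke-test sketch, not part of the module API. It assumes BertVAEConfig behaves
# like a BertConfig with an extra `position_num` field (the classification head size);
# adjust to the real config. Note that from_pretrained('bert-base-uncased') downloads weights.
if __name__ == "__main__":
    from transformers import BertTokenizer

    config = BertVAEConfig(position_num=2)  # assumed constructor signature
    model = BertVAE(config)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    batch = tokenizer(["a tiny example sentence"], return_tensors="pt")
    input_ids = batch["input_ids"]

    x_hat, mu, log_var = model(input_ids)                  # reconstruction and latent parameters
    loss = model.elbo(input_ids)                           # negative ELBO used as the training loss
    logits = model.encoder_cls(input_ids)                  # (1, position_num) classification logits
    samples = model.sample(num_samples=2, device="cpu")    # decode latents drawn from the prior

    print(x_hat.shape, loss.item(), logits.shape, samples.shape)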