import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.models.bert.modeling_bert import BertEncoder, BertModel

from configs import BertVAEConfig


class BertVAE(PreTrainedModel):
    config_class = BertVAEConfig

    def __init__(self, config):
        super().__init__(config)
        # Trainable VAE encoder/decoder stacks on top of a frozen BERT backbone.
        self.encoder = BertEncoder(config)
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Project encoder states to the mean and log-variance of the posterior.
        self.fc_mu = nn.Linear(config.hidden_size, config.hidden_size)
        self.fc_var = nn.Linear(config.hidden_size, config.hidden_size)
        # Position classifiers on the [CLS] states of encoder and decoder.
        self.enc_cls = nn.Linear(config.hidden_size, config.position_num)
        self.dec_cls = nn.Linear(config.hidden_size, config.position_num)
        self.decoder = BertEncoder(config)

        # Freeze the pre-trained BERT backbone; only the VAE components are trained.
        for p in self.bert.parameters():
            p.requires_grad = False

    def encode(self, input_ids, **kwargs):
        '''
        input_ids: (batch_size, seq_len) token ids; additional keyword
        arguments are forwarded to the encoder stack.
        Returns (mu, log_var), each of shape (batch_size, seq_len, hidden_size).
        '''
        x = self.bert(input_ids).last_hidden_state
        outputs = self.encoder(x, **kwargs)
        hidden_state = outputs.last_hidden_state
        mu = self.fc_mu(hidden_state)
        log_var = self.fc_var(hidden_state)
        return mu, log_var

    def encoder_cls(self, input_ids, **kwargs):
        '''
        input_ids: (batch_size, seq_len) token ids.
        Returns position logits of shape (batch_size, position_num) computed
        from the encoder's [CLS] hidden state.
        '''
        x = self.bert(input_ids).last_hidden_state
        outputs = self.encoder(x, **kwargs)
        hidden_state = outputs.last_hidden_state
        return self.enc_cls(hidden_state[:, 0, :])

    def decoder_cls(self, z, **kwargs):
        '''
        z: latent tensor of shape (batch_size, seq_len, hidden_size).
        Returns position logits of shape (batch_size, position_num) computed
        from the decoder's [CLS] hidden state.
        '''
        outputs = self.decoder(z, **kwargs)
        hidden_state = outputs.last_hidden_state
        return self.dec_cls(hidden_state[:, 0, :])

    def reparameterize(self, mu, log_var):
        '''
        Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        so the sampling step stays differentiable w.r.t. mu and log_var.
        '''
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, **kwargs):
        '''
        z: latent tensor of shape (batch_size, seq_len, hidden_size).
        Returns reconstructed hidden states of the same shape.
        '''
        outputs = self.decoder(z, **kwargs)
        return outputs.last_hidden_state

    def forward(self, input_ids, position=None, **kwargs):
        '''
        input_ids: (batch_size, seq_len) token ids; position is currently unused.
        Returns (x_hat, mu, log_var).
        '''
        mu, log_var = self.encode(input_ids, **kwargs)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, **kwargs), mu, log_var

    def _elbo(self, x, x_hat, mu, log_var):
        '''
        Compute the negative ELBO from targets and model outputs.
        x: target hidden states of shape (batch_size, seq_len, hidden_size)
        x_hat: reconstructed hidden states of the same shape
        mu: posterior mean of shape (batch_size, seq_len, hidden_size)
        log_var: posterior log-variance of the same shape
        '''
        # Reconstruction term: mean squared error against the frozen BERT features.
        recon_loss = nn.functional.mse_loss(x_hat, x, reduction='mean')
        # KL divergence of N(mu, sigma^2) from the standard normal prior,
        # summed over the hidden dimension and averaged over batch and positions.
        kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1))
        # The KL term is down-weighted by a factor of 0.1.
        return recon_loss + 0.1 * kl_loss

    def elbo(self, input_ids, **kwargs):
        '''
        Compute the negative ELBO directly from token ids.
        input_ids: (batch_size, seq_len) token ids.
        '''
        # Target features from the frozen BERT backbone.
        x = self.bert(input_ids, **kwargs).last_hidden_state
        # Approximate posterior parameters from the trainable encoder.
        outputs = self.encoder(x, **kwargs)
        hidden_state = outputs.last_hidden_state
        mu = self.fc_mu(hidden_state)
        log_var = self.fc_var(hidden_state)
        # Sample a latent and decode the reconstruction.
        z = self.reparameterize(mu, log_var)
        outputs = self.decoder(z, **kwargs)
        x_hat = outputs.last_hidden_state
        return self._elbo(x, x_hat, mu, log_var)

    def reconstruct(self, input_ids, **kwargs):
        '''
        input_ids: (batch_size, seq_len) token ids.
        Returns reconstructed hidden states of shape (batch_size, seq_len, hidden_size).
        '''
        return self.forward(input_ids, **kwargs)[0]

    def sample(self, num_samples, device, **kwargs):
        '''
        Draw num_samples latents from the standard normal prior and decode them.
        Returns hidden states of shape
        (num_samples, max_position_embeddings, hidden_size).
        '''
        z = torch.randn(num_samples, self.config.max_position_embeddings, self.config.hidden_size).to(device)
        return self.decode(z, **kwargs)
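

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal sketch of training BertVAE with the negative-ELBO loss, assuming
# that BertVAEConfig can be instantiated with its defaults and exposes the
# BERT-style fields this module reads (hidden_size, position_num,
# max_position_embeddings). The tokenizer name and the placeholder `texts`
# are assumptions for illustration, not part of the original module.
if __name__ == '__main__':
    from transformers import BertTokenizerFast

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    config = BertVAEConfig()          # assumed to provide sensible defaults
    model = BertVAE(config).to(device)

    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    texts = ['an example sentence', 'another example sentence']  # placeholder data
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    input_ids = batch['input_ids'].to(device)

    # Only the VAE components are trainable; the BERT backbone is frozen in __init__.
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=1e-4
    )

    loss = model.elbo(input_ids)      # negative ELBO: reconstruction + 0.1 * KL
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()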