# fengshen/models/DAVAE/DAVAEModel.py
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence  # used by latent_code_from_text_batch
from transformers import BertConfig, TransfoXLConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import cached_path, hf_bucket_url
from fengshen.models.DAVAE.GPT2ModelForLatent import GPT2ModelForLatent
from fengshen.models.DAVAE.BertForLatentConnector import BertForLatentConnector
from fengshen.models.DAVAE.run_latent_generation import *
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def log_sum_exp(value, dim=None, keepdim=False):
"""Numerically stable implementation of the operation
value.exp().sum(dim, keepdim).log()
"""
if dim is not None:
m, _ = torch.max(value, dim=dim, keepdim=True)
value0 = value - m
if keepdim is False:
m = m.squeeze(dim)
return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim))
else:
m = torch.max(value)
sum_exp = torch.sum(torch.exp(value - m))
return m + torch.log(sum_exp)
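# Illustrative check (an assumption about intended behavior, not part of the original file):
# log_sum_exp should agree with torch.logsumexp, which computes the same quantity stably, e.g.
#   x = torch.randn(4, 8)
#   torch.allclose(log_sum_exp(x, dim=-1), torch.logsumexp(x, dim=-1))  # expected: True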
class VAEPretrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
        pass  # bypass the NotImplementedError raised by the base class
class DAVAEModel(VAEPretrainedModel):
config_class = PretrainedConfig
    def __init__(self, config: PretrainedConfig, *model_args, **model_kwargs):
super().__init__(config=config)
self.config = config
self.vae_model = DAVAEModel.load_model(self.config)
@classmethod
def load_model(cls, config):
encoder_config = BertConfig.from_dict(config.encoder)
encoder_model = BertForLatentConnector(config=encoder_config, latent_size=config.latent_size)
dec_config = TransfoXLConfig.from_dict(config.decoder)
dec_config.latent_size = config.latent_size
decoder_model = GPT2ModelForLatent(config=dec_config)
vae_model = EncDecAAE(config,encoder_model, decoder_model, dec_config.latent_size, pad_token_id=50000)
return vae_model
def set_tokenizers(self,encoder_tokenizer,decoder_tokenizer):
if not hasattr(self, 'encoder_tokenizer'):
self.encoder_tokenizer = encoder_tokenizer
if not hasattr(self, 'decoder_tokenizer'):
self.decoder_tokenizer = decoder_tokenizer
def simulate_batch(self,encoder_tokenizer,decoder_tokenizer, sent_inputs, prompt=None):
self.set_tokenizers(encoder_tokenizer,decoder_tokenizer)
        # Generate similar sentences (paraphrases)
latent_z = self.latent_code_from_text_batch(sent_inputs)
text_analogy = self.text_from_latent_code_batch(latent_z,prompt=prompt)
return text_analogy
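    # Usage sketch for simulate_batch (hedged: the tokenizer objects and input strings below are
    # placeholders, not defined in this file). Given an encoder tokenizer and a decoder tokenizer,
    # it encodes each sentence to a latent code and decodes a paraphrase:
    #   paraphrases = model.simulate_batch(enc_tokenizer, dec_tokenizer,
    #                                      ["这是一个例子"], prompt=None)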
def latent_code_from_text_batch(self,texts):
# texts->latents
tokens_tensor_list = []
for text in texts:
tokens = self.encoder_tokenizer.encode(text)[:510]
tokens_tensor_list.append(torch.tensor([101]+tokens+[102]))
coded = pad_sequence(tokens_tensor_list, batch_first=True, padding_value=0).long()
device = next(self.vae_model.decoder.parameters()).device
with torch.no_grad():
coded = coded.to(device)
pooled_hidden_fea = self.vae_model.encoder(coded, attention_mask=(coded > 0).float())[1]
mean, logvar = self.vae_model.encoder.linear(pooled_hidden_fea).chunk(2, -1)
std = logvar.mul(0.5).exp()
eps = torch.zeros_like(std).normal_()
latent_z = mean + torch.mul(eps, std)*self.config.std_scale
return latent_z
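    # Note: latent_z has shape (batch_size, config.latent_size); config.std_scale rescales the
    # sampled posterior noise, so a scale of 0 would make the encoding deterministic (z = mean).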
def text_from_latent_code_batch(self,latent_z, prompt=None):
# latents->texts
device = next(self.vae_model.decoder.parameters()).device
past = latent_z
batch_size = latent_z.shape[0]
bos_token = self.decoder_tokenizer.convert_tokens_to_ids(self.decoder_tokenizer.bos_token)
end_token = self.decoder_tokenizer.convert_tokens_to_ids(self.decoder_tokenizer.eos_token)
if prompt is not None:
prompt = [[bos_token] + self.decoder_tokenizer.encode(text)[:-1] for text in prompt]
else:
prompt = [[bos_token]]*batch_size
context_tokens_tensor = torch.tensor([[end_token]*self.config.max_out_length]*batch_size).to(device) # 2-d tensor
context_length_tensor = torch.tensor([1]*batch_size).to(device)
for i in range(batch_size):
context_tokens_tensor[i,:len(prompt[i])] = torch.tensor(prompt[i]).long().to(device)
context_length_tensor[i] = len(prompt[i])
out = sample_sequence_conditional_batch(
model=self.vae_model.decoder,
max_out_length= self.config.max_out_length,
context_tokens_tensor=context_tokens_tensor,
context_length_tensor=context_length_tensor,
latent_z=latent_z,
temperature=self.config.temperature,
top_k=self.config.top_k,
top_p=self.config.top_p,
repetition_penalty=self.config.repetition_penalty,
device=device
)
out_text = []
for i, tokens in enumerate(out):
tokens = tokens[len(prompt[i]):]
tokens = tokens[:tokens.index(end_token)] if end_token in tokens else tokens
text = self.decoder_tokenizer.decode(tokens, clean_up_tokenization_spaces=True)
out_text.append(filter_noise(text))
return out_text
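# End-to-end sketch (hedged: the checkpoint path and tokenizer classes are assumptions for
# illustration; consult the Fengshenbang model cards for the actual names):
#   from transformers import BertTokenizer
#   enc_tok = BertTokenizer.from_pretrained("...")   # encoder-side (BERT) tokenizer
#   dec_tok = ...                                     # decoder-side tokenizer with bos/eos tokens
#   model = DAVAEModel.from_pretrained("...").eval()
#   print(model.simulate_batch(enc_tok, dec_tok, ["今天天气很好"]))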
class EncDecAAE(nn.Module):
"""Adversarial Auto-Encoder"""
def __init__(self,config, encoder, decoder, latent_size, pad_token_id):
super(EncDecAAE, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.config = config
self.pad_token_id = pad_token_id
self.Disc = nn.Sequential(nn.Linear(latent_size, 4*latent_size), nn.ReLU(),
nn.Linear(4*latent_size, 1))
# Standard Normal prior
loc = torch.zeros(latent_size)
scale = torch.ones(latent_size)
self.prior = torch.distributions.normal.Normal(loc, scale)
def connect(self, bert_fea, nsamples=1, fb_mode=0):
"""
Returns: Tensor1, Tensor2
Tensor1: the tensor latent z with shape [batch, nsamples, nz]
            Tensor2: the tensor of KL for each x with shape [batch]
"""
# (batch_size, nz)
mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
z = self.reparameterize(mean, logvar, nsamples)
if fb_mode == 0:
KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
elif fb_mode == 1:
kl_loss = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1)
kl_mask = (kl_loss > self.config.dim_target_kl).float()
KL = (kl_mask * kl_loss).sum(dim=1)
return z, KL
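    # The KL above is the closed form between the diagonal Gaussian posterior N(mean, exp(logvar))
    # and the standard normal prior: 0.5 * sum(mean^2 + exp(logvar) - logvar - 1). With fb_mode=1
    # each dimension only contributes once its KL exceeds config.dim_target_kl ("free bits").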
def connect_deterministic(self, bert_fea, nsamples=1):
"""
Returns: Tensor1, Tensor2
Tensor1: the tensor latent z with shape [batch, nsamples, nz]
            Tensor2: the tensor of KL for each x with shape [batch]
"""
# (batch_size, nz)
mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
logvar = torch.zeros_like(logvar)
z = self.reparameterize(mean, logvar, nsamples)
KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
return z, KL
def reparameterize(self, mu, logvar, nsamples=1):
"""sample from posterior Gaussian family
Args:
mu: Tensor
Mean of gaussian distribution with shape (batch, nz)
logvar: Tensor
                logvar of the Gaussian distribution with shape (batch, nz)
Returns: Tensor
Sampled z with shape (batch, nsamples, nz)
"""
batch_size, nz = mu.size()
std = logvar.mul(0.5).exp()
mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)
eps = torch.zeros_like(std_expd).normal_()
return mu_expd + torch.mul(eps, std_expd)
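    # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I), so samples stay
    # differentiable with respect to mu and logvar; nsamples > 1 draws several z per input.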
def loss_adv(self, z):
zn = torch.randn_like(z)
zeros = torch.zeros(len(z), 1, device=z.device).half()
ones = torch.ones(len(z), 1, device=z.device).half()
loss_d = F.binary_cross_entropy_with_logits(self.Disc(z.detach().half()), zeros) + \
F.binary_cross_entropy_with_logits(self.Disc(zn.half()), ones)
loss_g = F.binary_cross_entropy_with_logits(self.Disc(z.half()), ones)
return loss_d, loss_g
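    # Adversarial regularization: loss_d trains the discriminator to label posterior codes z as 0
    # and prior samples zn ~ N(0, I) as 1, while loss_g pushes the encoder to fool it, pulling the
    # aggregate posterior toward the prior. The .half() casts assume fp16 training.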
def forward(self, inputs, labels, beta=0.0, iw=None, fb_mode=0, emb_noise=None):
attention_mask = (inputs > 0).float()
        reconstruction_mask = (labels != self.pad_token_id).float()  # ignore the GPT-2 padding token in the loss
        sent_length = torch.sum(reconstruction_mask, dim=1)
outputs = self.encoder(inputs, attention_mask, emb_noise=emb_noise)
pooled_hidden_fea = outputs[1]
seq_length = labels.size(1)
dec_attn_mask = self.decoder.get_attn_mask(seq_length).to(labels.device)
if fb_mode in [0,1]:
latent_z, loss_kl = self.connect(pooled_hidden_fea, fb_mode=fb_mode)
latent_z = latent_z.squeeze(1)
outputs = self.decoder(input_ids=labels, attention_mask=dec_attn_mask, latent_state=latent_z, labels=labels, label_ignore=self.pad_token_id) # ignore loss over padding tokens
loss_rec = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
elif fb_mode==2:
latent_z, loss_kl = self.connect_deterministic(pooled_hidden_fea)
latent_z = latent_z.squeeze(1)
outputs = self.decoder(input_ids=labels, attention_mask=dec_attn_mask, latent_state=latent_z, labels=labels, label_ignore=self.pad_token_id)
loss_rec = outputs[0] # model outputs are always tuple
if self.config.length_weighted_loss:
loss = loss_rec / sent_length + beta * loss_kl
else:
loss = loss_rec + beta * loss_kl
        if iw is not None:
total_loss = torch.sum(loss*iw)/torch.sum(iw)
else:
total_loss = torch.sum(loss)
return (loss_rec/sent_length).mean(), loss_kl.mean(), total_loss
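        # forward returns (mean per-token reconstruction loss, mean KL, total loss); beta weights
        # the KL term and iw, if given, applies per-example importance weights to the total loss.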