import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from transformers import BertConfig, TransfoXLConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import cached_path, hf_bucket_url

from fengshen.models.DAVAE.GPT2ModelForLatent import GPT2ModelForLatent
from fengshen.models.DAVAE.BertForLatentConnector import BertForLatentConnector
from fengshen.models.DAVAE.run_latent_generation import *


def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation

    value.exp().sum(dim, keepdim).log()
    """
    if dim is not None:
        m, _ = torch.max(value, dim=dim, keepdim=True)
        value0 = value - m
        if keepdim is False:
            m = m.squeeze(dim)
        return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim))
    else:
        m = torch.max(value)
        sum_exp = torch.sum(torch.exp(value - m))
        return m + torch.log(sum_exp)

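
# Illustrative sanity check (not part of the original module): for finite
# inputs, log_sum_exp above should agree with torch.logsumexp along the same
# dimension. The tensor used here is a made-up example.
def _log_sum_exp_sanity_check():
    scores = torch.randn(4, 8)
    return torch.allclose(log_sum_exp(scores, dim=-1), torch.logsumexp(scores, dim=-1))
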
class VAEPretrainedModel(PreTrainedModel):
    def _init_weights(self, module):
        """Initialize the weights."""
        pass

class DAVAEModel(VAEPretrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config: PretrainedConfig, *model_args, **model_kwargs):
        super().__init__(config=config)
        self.config = config
        self.vae_model = DAVAEModel.load_model(self.config)

    @classmethod
    def load_model(cls, config):
        encoder_config = BertConfig.from_dict(config.encoder)
        encoder_model = BertForLatentConnector(config=encoder_config, latent_size=config.latent_size)
        dec_config = TransfoXLConfig.from_dict(config.decoder)
        dec_config.latent_size = config.latent_size
        decoder_model = GPT2ModelForLatent(config=dec_config)
        vae_model = EncDecAAE(config, encoder_model, decoder_model, dec_config.latent_size, pad_token_id=50000)
        return vae_model

    def set_tokenizers(self, encoder_tokenizer, decoder_tokenizer):
        if not hasattr(self, 'encoder_tokenizer'):
            self.encoder_tokenizer = encoder_tokenizer
        if not hasattr(self, 'decoder_tokenizer'):
            self.decoder_tokenizer = decoder_tokenizer

    def simulate_batch(self, encoder_tokenizer, decoder_tokenizer, sent_inputs, prompt=None):
        """Encode `sent_inputs` into latent codes, then decode them back into
        text, optionally conditioned on `prompt`."""
        self.set_tokenizers(encoder_tokenizer, decoder_tokenizer)
        latent_z = self.latent_code_from_text_batch(sent_inputs)
        text_analogy = self.text_from_latent_code_batch(latent_z, prompt=prompt)
        return text_analogy

    def latent_code_from_text_batch(self, texts):
        tokens_tensor_list = []
        for text in texts:
            # Truncate to 510 tokens so the sequence stays within 512 once the
            # hard-coded BERT special tokens [CLS] (101) and [SEP] (102) are added.
            tokens = self.encoder_tokenizer.encode(text)[:510]
            tokens_tensor_list.append(torch.tensor([101] + tokens + [102]))

        coded = pad_sequence(tokens_tensor_list, batch_first=True, padding_value=0).long()
        device = next(self.vae_model.decoder.parameters()).device
        with torch.no_grad():
            coded = coded.to(device)
            pooled_hidden_fea = self.vae_model.encoder(coded, attention_mask=(coded > 0).float())[1]
            mean, logvar = self.vae_model.encoder.linear(pooled_hidden_fea).chunk(2, -1)

            std = logvar.mul(0.5).exp()
            eps = torch.zeros_like(std).normal_()

            # Reparameterized sample, with the noise scaled by config.std_scale.
            latent_z = mean + torch.mul(eps, std) * self.config.std_scale
        return latent_z

    def text_from_latent_code_batch(self, latent_z, prompt=None):
        device = next(self.vae_model.decoder.parameters()).device
        batch_size = latent_z.shape[0]
        bos_token = self.decoder_tokenizer.convert_tokens_to_ids(self.decoder_tokenizer.bos_token)
        end_token = self.decoder_tokenizer.convert_tokens_to_ids(self.decoder_tokenizer.eos_token)

        if prompt is not None:
            prompt = [[bos_token] + self.decoder_tokenizer.encode(text)[:-1] for text in prompt]
        else:
            prompt = [[bos_token]] * batch_size

        # Right-pad every prompt with the EOS id up to max_out_length and record
        # each prompt's true length for the conditional sampler.
        context_tokens_tensor = torch.tensor([[end_token] * self.config.max_out_length] * batch_size).to(device)
        context_length_tensor = torch.tensor([1] * batch_size).to(device)
        for i in range(batch_size):
            context_tokens_tensor[i, :len(prompt[i])] = torch.tensor(prompt[i]).long().to(device)
            context_length_tensor[i] = len(prompt[i])

        out = sample_sequence_conditional_batch(
            model=self.vae_model.decoder,
            max_out_length=self.config.max_out_length,
            context_tokens_tensor=context_tokens_tensor,
            context_length_tensor=context_length_tensor,
            latent_z=latent_z,
            temperature=self.config.temperature,
            top_k=self.config.top_k,
            top_p=self.config.top_p,
            repetition_penalty=self.config.repetition_penalty,
            device=device
        )

        out_text = []
        for i, tokens in enumerate(out):
            # Strip the prompt, cut at the first EOS, then decode and denoise.
            tokens = tokens[len(prompt[i]):]
            tokens = tokens[:tokens.index(end_token)] if end_token in tokens else tokens
            text = self.decoder_tokenizer.decode(tokens, clean_up_tokenization_spaces=True)
            out_text.append(filter_noise(text))
        return out_text

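
# Usage sketch (illustrative, not part of the original module). The checkpoint
# path and tokenizer classes below are assumptions; substitute the pretrained
# DAVAE checkpoint and the encoder/decoder tokenizers you actually use. The
# decoder tokenizer must expose bos_token / eos_token, which
# text_from_latent_code_batch relies on.
def _davae_generation_sketch():
    from transformers import AutoTokenizer
    checkpoint = "path/to/davae-checkpoint"  # hypothetical path or hub id
    encoder_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    decoder_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = DAVAEModel.from_pretrained(checkpoint).eval()
    sentences = ["这是一个测试句子。"]  # example input sentence
    return model.simulate_batch(encoder_tokenizer, decoder_tokenizer, sentences)
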
class EncDecAAE(nn.Module):
    """Adversarial Auto-Encoder."""

    def __init__(self, config, encoder, decoder, latent_size, pad_token_id):
        super(EncDecAAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        self.pad_token_id = pad_token_id
        self.Disc = nn.Sequential(nn.Linear(latent_size, 4 * latent_size), nn.ReLU(),
                                  nn.Linear(4 * latent_size, 1))

        loc = torch.zeros(latent_size)
        scale = torch.ones(latent_size)
        self.prior = torch.distributions.normal.Normal(loc, scale)

    def connect(self, bert_fea, nsamples=1, fb_mode=0):
        """
        Returns: Tensor1, Tensor2
            Tensor1: the latent code z with shape [batch, nsamples, nz]
            Tensor2: the tensor of KL for each x with shape [batch]
        """
        mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
        z = self.reparameterize(mean, logvar, nsamples)
        if fb_mode == 0:
            KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
        elif fb_mode == 1:
            # Per-dimension KL thresholding: dimensions whose KL falls below
            # config.dim_target_kl are masked out of the loss.
            kl_loss = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1)
            kl_mask = (kl_loss > self.config.dim_target_kl).float()
            KL = (kl_mask * kl_loss).sum(dim=1)

        return z, KL

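    # Reference for the KL terms in connect() above and connect_deterministic()
    # below: with a diagonal Gaussian posterior q(z|x) = N(mu, diag(sigma^2))
    # and a standard normal prior p(z) = N(0, I),
    #     KL(q || p) = 0.5 * sum_i (mu_i^2 + sigma_i^2 - log sigma_i^2 - 1),
    # which, with logvar = log sigma^2, is exactly
    # 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1) summed over latent dims.
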
    def connect_deterministic(self, bert_fea, nsamples=1):
        """
        Returns: Tensor1, Tensor2
            Tensor1: the latent code z with shape [batch, nsamples, nz]
            Tensor2: the tensor of KL for each x with shape [batch]
        """
        mean, logvar = self.encoder.linear(bert_fea).chunk(2, -1)
        # Zeroing logvar fixes the posterior variance to 1, so the sample is
        # mean plus unit-variance noise and the KL reduces to 0.5 * ||mean||^2.
        logvar = torch.zeros_like(logvar)
        z = self.reparameterize(mean, logvar, nsamples)
        KL = 0.5 * (mean.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)

        return z, KL

    def reparameterize(self, mu, logvar, nsamples=1):
        """Sample from the posterior Gaussian family.

        Args:
            mu: Tensor
                Mean of the Gaussian distribution with shape (batch, nz)
            logvar: Tensor
                Log variance of the Gaussian distribution with shape (batch, nz)

        Returns: Tensor
            Sampled z with shape (batch, nsamples, nz)
        """
        batch_size, nz = mu.size()
        std = logvar.mul(0.5).exp()

        mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
        std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)

        eps = torch.zeros_like(std_expd).normal_()

        return mu_expd + torch.mul(eps, std_expd)

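    # Shape sketch for reparameterize() above (illustrative values): with mu and
    # logvar of shape (batch, nz) = (2, 16) and nsamples=3, the returned sample
    # has shape (2, 3, 16). forward() below squeezes away the nsamples dimension,
    # which is 1 under the default arguments.
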
    def loss_adv(self, z):
        zn = torch.randn_like(z)
        zeros = torch.zeros(len(z), 1, device=z.device).half()
        ones = torch.ones(len(z), 1, device=z.device).half()

        loss_d = F.binary_cross_entropy_with_logits(self.Disc(z.detach().half()), zeros) + \
            F.binary_cross_entropy_with_logits(self.Disc(zn.half()), ones)
        loss_g = F.binary_cross_entropy_with_logits(self.Disc(z.half()), ones)
        return loss_d, loss_g

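    # Note on loss_adv() above: the discriminator `Disc` is trained (loss_d) to
    # separate posterior samples z (target 0) from prior samples zn ~ N(0, I)
    # (target 1), while loss_g rewards the encoder for producing z that the
    # discriminator scores as prior-like. This is the adversarial half of the
    # Adversarial Auto-Encoder objective; the reconstruction/KL half lives in
    # forward() below.
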
    def forward(self, inputs, labels, beta=0.0, iw=None, fb_mode=0, emb_noise=None):
        """Run one encoder-decoder pass and return the mean length-normalized
        reconstruction loss, the mean KL, and the total (optionally
        importance-weighted) loss."""
        attention_mask = (inputs > 0).float()
        reconstruction_mask = (labels != self.pad_token_id).float()
        sent_length = torch.sum(reconstruction_mask, dim=1)

        outputs = self.encoder(inputs, attention_mask, emb_noise=emb_noise)
        pooled_hidden_fea = outputs[1]

        seq_length = labels.size(1)
        dec_attn_mask = self.decoder.get_attn_mask(seq_length).to(labels.device)

        if fb_mode in [0, 1]:
            latent_z, loss_kl = self.connect(pooled_hidden_fea, fb_mode=fb_mode)
            latent_z = latent_z.squeeze(1)
            outputs = self.decoder(input_ids=labels, attention_mask=dec_attn_mask, latent_state=latent_z,
                                   labels=labels, label_ignore=self.pad_token_id)
            loss_rec = outputs[0]
        elif fb_mode == 2:
            latent_z, loss_kl = self.connect_deterministic(pooled_hidden_fea)
            latent_z = latent_z.squeeze(1)
            outputs = self.decoder(input_ids=labels, attention_mask=dec_attn_mask, latent_state=latent_z,
                                   labels=labels, label_ignore=self.pad_token_id)
            loss_rec = outputs[0]

        if self.config.length_weighted_loss:
            loss = loss_rec / sent_length + beta * loss_kl
        else:
            loss = loss_rec + beta * loss_kl

        # Optional importance weights `iw` reweight the per-example losses.
        if iw is not None:
            total_loss = torch.sum(loss * iw) / torch.sum(iw)
        else:
            total_loss = torch.sum(loss)
        return (loss_rec / sent_length).mean(), loss_kl.mean(), total_loss

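
# Training sketch (an assumption, not part of the original module): a typical
# AAE update would combine the losses returned by EncDecAAE.forward() with the
# two terms from loss_adv(), e.g. stepping a discriminator optimizer on loss_d
# and an encoder/decoder optimizer on total_loss plus a weighted loss_g. The
# optimizer split and the weighting are illustrative; the actual training loop
# lives outside this file.
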