# denosent-bert-base / model.py
from transformers import AutoTokenizer, PreTrainedModel
from transformers.models.bert.modeling_bert import BertForMaskedLM
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, TransformerDecoder, TransformerDecoderLayer
from typing import Optional
import wandb
import numpy as np
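

# DenoSentModel: a sentence-embedding model that pairs a pretrained BERT encoder with
#   * a linear projector over the pooled sentence embedding,
#   * a TransformerDecoder that reconstructs the sentence from that embedding
#     (the denoising / generative objective), and
#   * a cosine-similarity head used for a contrastive (InfoNCE-style) objective.
# The pooling strategy ('cls', 'mean', or 'mask') is selected by config.pooler.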
class DenoSentModel(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.pooler = config.pooler
self.sent_embedding_projector = nn.Linear(config.hidden_size, config.hidden_size)
self.decoder = TransformerDecoder(TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.decoder_num_heads, batch_first=True, dropout=0.1), num_layers=config.decoder_num_layers)
self.decoder_noise_dropout = nn.Dropout(config.decoder_noise_dropout)
self.sim = nn.CosineSimilarity(dim=-1)
self.init_weights()
self.tokenizer = AutoTokenizer.from_pretrained(config.encoder_name_or_path)
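        # Load the full MLM model once, keep its LM head as the vocabulary prediction
        # head for the decoder, and keep only the bare BERT encoder for sentence encoding.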
self.encoder = BertForMaskedLM.from_pretrained(config.encoder_name_or_path)
self.prediction_head = self.encoder.cls
self.encoder = self.encoder.bert
self.post_init()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def encode(self, sentences, batch_size=32, **kwargs):
""" Returns a list of embeddings for the given sentences.
Args:
sentences (`List[str]`): List of sentences to encode
batch_size (`int`): Batch size for the encoding
Returns:
`List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
"""
self.eval()
all_embeddings = []
length_sorted_idx = np.argsort([len(sen) for sen in sentences])
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
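        # For the 'mask' pooler, each sentence is wrapped in config.prompt_format
        # ('[X]' is replaced by the sentence, '[MASK]' by the mask token) and the
        # embedding is read from the [MASK] position; prompt_length accounts for the
        # extra prompt tokens when padding/truncating, and stays 0 otherwise.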
        prompt_length = 0
        if self.config.pooler == 'mask':
prompt_length = len(self.tokenizer(self.config.prompt_format, add_special_tokens=False)['input_ids'])
sentences_sorted = self.tokenizer.batch_decode(self.tokenizer(sentences_sorted, padding=True, truncation=True, max_length=self.config.max_length, return_tensors='pt').input_ids, skip_special_tokens=True)
sentences_sorted = [self.config.prompt_format.replace('[X]', s).replace('[MASK]', self.tokenizer.mask_token) for s in sentences_sorted]
for start_index in range(0, len(sentences), batch_size):
sentences_batch = sentences_sorted[start_index:start_index+batch_size]
inputs = self.tokenizer(sentences_batch, padding='max_length', truncation=True, return_tensors="pt", max_length=self.config.max_length+prompt_length)
inputs = {k: v.to(self.device) for k,v in inputs.items()}
with torch.no_grad():
encoder_outputs = self.encoder(**inputs, output_hidden_states=True, output_attentions=True, return_dict=True)
last_hidden_state = encoder_outputs.last_hidden_state
if self.config.pooler == 'cls':
embeddings = last_hidden_state[:, 0, :]
elif self.config.pooler == 'mean':
embeddings = (last_hidden_state * inputs['attention_mask'].unsqueeze(-1)).sum(1) / inputs['attention_mask'].sum(-1).unsqueeze(-1)
elif self.pooler == 'mask':
embeddings = last_hidden_state[inputs['input_ids'] == self.tokenizer.mask_token_id]
else:
raise NotImplementedError()
all_embeddings.extend(embeddings.cpu().numpy())
all_embeddings = torch.tensor(np.array([all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]))
return all_embeddings
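
    # Training forward pass: anchors, optional positives, and optional hard negatives
    # are concatenated into one encoder batch (so rows [0:B), [B:2B), [2B:3B) line up),
    # pooled into sentence embeddings, and fed to the contrastive and/or generative
    # objectives enabled in the config.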
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
positive_input_ids: Optional[torch.LongTensor] = None,
positive_attention_mask: Optional[torch.LongTensor] = None,
negative_input_ids: Optional[torch.LongTensor] = None,
negative_attention_mask: Optional[torch.LongTensor] = None,
global_step: Optional[int] = None,
max_steps: Optional[int] = None,
):
batch_size = input_ids.size(0)
if negative_input_ids is not None:
encoder_input_ids = torch.cat([input_ids, positive_input_ids, negative_input_ids], dim=0).to(self.device)
encoder_attention_mask = torch.cat([attention_mask, positive_attention_mask, negative_attention_mask], dim=0).to(self.device)
elif positive_input_ids is not None:
encoder_input_ids = torch.cat([input_ids, positive_input_ids], dim=0).to(self.device)
encoder_attention_mask = torch.cat([attention_mask, positive_attention_mask], dim=0).to(self.device)
elif self.config.do_contrastive:
encoder_input_ids = torch.cat([input_ids, input_ids], dim=0).to(self.device)
encoder_attention_mask = torch.cat([attention_mask, attention_mask], dim=0).to(self.device)
elif self.config.do_generative and not self.config.do_contrastive:
encoder_input_ids = input_ids.to(self.device)
encoder_attention_mask = attention_mask.to(self.device)
else:
raise NotImplementedError()
encoder_outputs = self.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask, return_dict=True, output_hidden_states=True, output_attentions=True)
if self.pooler == 'cls':
sent_embedding = encoder_outputs.last_hidden_state[:, 0, :]
elif self.pooler == 'mean':
sent_embedding = ((encoder_outputs.last_hidden_state * encoder_attention_mask.unsqueeze(-1)).sum(1) / encoder_attention_mask.sum(-1).unsqueeze(-1))
elif self.pooler == 'mask':
sent_embedding = encoder_outputs.last_hidden_state[encoder_input_ids == self.tokenizer.mask_token_id]
else:
raise NotImplementedError()
sent_embedding = sent_embedding.unsqueeze(1)
sent_embedding = self.sent_embedding_projector(sent_embedding)
if self.config.do_generative:
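            # Generative (denoising) branch: the decoder reconstructs the positive
            # sentence (or the input itself) from its token embeddings, which are
            # detached and corrupted with dropout, using the projected sentence
            # embedding as the only memory token. tgt_mask is None, so reconstruction
            # is non-autoregressive (one pass over all positions).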
if positive_input_ids is not None:
tgt = encoder_outputs.hidden_states[0][batch_size:2*batch_size].detach()
tgt_key_padding_mask = (positive_input_ids == self.tokenizer.pad_token_id)
labels = positive_input_ids
else:
tgt = encoder_outputs.hidden_states[0][:batch_size].detach()
tgt_key_padding_mask = (input_ids == self.tokenizer.pad_token_id)
labels = input_ids
tgt = self.decoder_noise_dropout(tgt)
decoder_outputs = self.decoder(tgt=tgt, memory=sent_embedding[:batch_size], tgt_mask=None, tgt_key_padding_mask=tgt_key_padding_mask)
logits = self.prediction_head(decoder_outputs)
loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
generative_loss = loss_fct(logits.view(-1, self.encoder.config.vocab_size), labels.view(-1))
wandb.log({'train/generative_loss': generative_loss})
if self.config.do_contrastive:
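            # Contrastive branch: broadcasting (B, 1, H) against (1, B, H) yields a
            # B x B cosine-similarity matrix whose diagonal holds the positive pairs;
            # cross-entropy with labels 0..B-1 is the InfoNCE-style objective. Hard
            # negatives, if given, are appended as extra columns. When the batch only
            # duplicates the inputs, encoder dropout differentiates the two views.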
positive_sim = self.sim(sent_embedding[:batch_size], sent_embedding[batch_size:2*batch_size].transpose(0, 1))
cos_sim = positive_sim
if negative_attention_mask is not None:
negative_sim = self.sim(sent_embedding[:batch_size], sent_embedding[2*batch_size:].transpose(0, 1))
cos_sim = torch.cat([positive_sim, negative_sim], dim=1)
cos_sim = cos_sim / self.config.contrastive_temp
contrastive_labels = torch.arange(batch_size, dtype=torch.long, device=self.device)
contrastive_loss = nn.CrossEntropyLoss()(cos_sim, contrastive_labels)
wandb.log({'train/contrastive_loss': contrastive_loss.item()})
logits = None
loss = 0
if self.config.do_contrastive:
loss += self.config.contrastive_weight * contrastive_loss
if self.config.do_generative:
loss += self.config.generative_weight * generative_loss
wandb.log({'train/loss': loss})
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
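

# Minimal usage sketch. Assumes the repository config registers this class for
# AutoModel via `auto_map` and that the hub id below is correct; adjust as needed.
if __name__ == "__main__":
    from transformers import AutoModel

    model = AutoModel.from_pretrained("Singhoo/denosent-bert-base", trust_remote_code=True)
    sentences = ["A man is playing a guitar.", "Someone plays the guitar."]
    embeddings = model.encode(sentences, batch_size=2)
    print(embeddings.shape)  # e.g. torch.Size([2, 768]) for a BERT-base encoder
    # Cosine similarity between the two sentence embeddings.
    print(F.cosine_similarity(embeddings[0], embeddings[1], dim=-1).item())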