"""PyTorch BERT model. """ |
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import math
import os
import sys
from io import open

import torch
from torch import nn

from transformers import BertConfig, BertPreTrainedModel
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler
|
|
class BertForLatentConnector(BertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Last layer hidden-state of the first token of the sequence (classification token),
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained on the next sentence prediction (classification)
            objective during BERT pretraining. This output is usually *not* a good summary
            of the semantic content of the input; you're often better off averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertModel.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
|
    def __init__(self, config, latent_size):
        super(BertForLatentConnector, self).__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        # Projects the pooled representation to 2 * latent_size features
        # (e.g. a mean and log-variance for a latent code).
        self.linear = nn.Linear(config.hidden_size, 2 * latent_size, bias=False)

        self.init_weights()
|
    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings
|
    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.

        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
|
    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, emb_noise=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask: the shape is
        # [batch_size, 1, 1, to_seq_length] so it can be broadcast to
        # [batch_size, num_heads, from_seq_length, to_seq_length].
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # attention_mask is 1.0 for positions we want to attend to and 0.0 for masked
        # positions; this turns masked positions into a large negative bias that is
        # added to the raw attention scores before the softmax.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed: 1.0 means the head is kept.
        # The input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length].
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # per-layer head mask
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)

        # Optionally perturb the embeddings with caller-provided noise before encoding.
        if emb_noise is not None:
            embedding_output = embedding_output + emb_noise(embedding_output).to(embedding_output.dtype)

        encoder_outputs = self.encoder(embedding_output,
                                       attention_mask=extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        # sequence_output, pooled_output, (hidden_states), (attentions)
        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
        return outputs
|
|
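# A minimal usage sketch, assuming the standard 'bert-base-uncased' checkpoint and
# BertTokenizer are available; latent_size=32 is an arbitrary illustrative value.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForLatentConnector.from_pretrained('bert-base-uncased', latent_size=32)
    model.eval()

    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
    with torch.no_grad():
        sequence_output, pooled_output = model(input_ids)[:2]
        # The extra linear head maps the pooled output to 2 * latent_size features,
        # e.g. a mean and log-variance for a latent code.
        mean, logvar = model.linear(pooled_output).chunk(2, dim=-1)
    print(sequence_output.shape, mean.shape, logvar.shape)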