# ME2-BERT / modeling_me2bert.py
# Page header captured with this file: author lorenzozan, last commit
# "Update modeling_me2bert.py" (46a120f, verified), file size 7.28 kB.
from transformers import PreTrainedModel
from transformers import AutoModel
import torch
from torch.autograd import Function
from configuration_me2bert import ME2BertConfig
class ReverseLayerF(Function):
    """Gradient reversal: identity forward, gradient scaled by ``-alpha`` backward.

    Use via ``ReverseLayerF.apply(x, alpha)``. ``alpha`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Stash the scale for the backward pass.
        ctx.alpha = alpha
        # Return a view (not ``x`` itself) so autograd records a new output.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        scale = ctx.alpha
        # One gradient per forward input: (x, alpha); alpha is non-differentiable.
        return -grad_output * scale, None
class FFClassifier(torch.nn.Module):
    """Two-layer feed-forward head returning raw (un-activated) scores.

    Layout: Linear -> BatchNorm1d -> ReLU (in place) -> Dropout -> Linear.

    Args:
        input_dim: number of input features.
        hidden_dim: width of the hidden layer.
        n_classes: number of output columns.
        dropout: dropout probability applied before the last Linear.
    """

    def __init__(self, input_dim, hidden_dim, n_classes, dropout=0.0):
        super().__init__()
        layers = [
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, n_classes),
        ]
        # Attribute name and layer order fix the state_dict keys (model.0 ... model.4),
        # which must match saved checkpoints.
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores
class Encoder(torch.nn.Module):
    """Two-layer MLP projecting ``input_dim`` features down to ``latent_dim``.

    Layout: fc1 -> PReLU (single learnable slope) -> fc2, no output activation.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Creation order (fc1, fc2, prelu) is kept so random init and
        # state_dict ordering are unchanged.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim, bias=True)
        self.fc2 = torch.nn.Linear(hidden_dim, latent_dim, bias=True)
        self.prelu = torch.nn.PReLU()

    def forward(self, x):
        hidden = self.prelu(self.fc1(x))
        return self.fc2(hidden)
class Decoder(torch.nn.Module):
    """Two-layer MLP mapping a ``latent_dim`` code back to ``output_dim`` features.

    Layout: fc1 -> PReLU (single learnable slope) -> fc2, no output activation.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        # Same creation order as before: keeps init RNG draws and state_dict keys.
        self.fc1 = torch.nn.Linear(latent_dim, hidden_dim, bias=True)
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim, bias=True)
        self.prelu = torch.nn.PReLU()

    def forward(self, x):
        return self.fc2(self.prelu(self.fc1(x)))
class AutoEncoder(torch.nn.Module):
    """Encoder -> LayerNorm -> Decoder bottleneck.

    Args:
        input_dim: feature size of the input (and of the reconstruction).
        hidden_dim: hidden width used inside both the encoder and decoder.
        latent_dim: size of the bottleneck code.

    ``forward(x)`` returns ``(encoded, decoded)``: ``encoded`` is the
    layer-normalised latent code with ``latent_dim`` features, and
    ``decoded`` is the decoder's reconstruction with ``input_dim`` features.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.layer_norm = torch.nn.LayerNorm(latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def forward(self, x):
        encoded = self.layer_norm(self.encoder(x))
        decoded = self.decoder(encoded)
        return encoded, decoded
class GatedCombination(torch.nn.Module):
    """Gated fusion of two embeddings of equal size ``embedding_dim``.

    With ``a = frozen_output`` and ``b = finetuned_output``:

        c   = sigmoid(W_f a) * a + sigmoid(W_i b) * b
        out = sigmoid(W_o c) * tanh(c)

    where ``W_f``, ``W_i`` and ``W_o`` are the ``forget_gate``,
    ``input_gate`` and ``output_gate`` Linear layers (with bias).
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        # Keep creation order so parameter init and state_dict keys are unchanged.
        self.forget_gate = torch.nn.Linear(embedding_dim, embedding_dim)
        self.input_gate = torch.nn.Linear(embedding_dim, embedding_dim)
        self.output_gate = torch.nn.Linear(embedding_dim, embedding_dim)
        self.sigmoid = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()

    def forward(self, frozen_output, finetuned_output):
        keep = self.sigmoid(self.forget_gate(frozen_output))
        admit = self.sigmoid(self.input_gate(finetuned_output))
        mixed = keep * frozen_output + admit * finetuned_output
        expose = self.sigmoid(self.output_gate(mixed))
        return expose * self.tanh(mixed)
class ME2BertModel(PreTrainedModel):
    """Scores text on five dimensions (``CH``, ``FC``, ``LB``, ``AS``, ``PD``).

    Sub-modules built in ``__init__``:

    * ``feature``: trainable ``AutoModel`` loaded from
      ``config.pretrained_model_name``.
    * ``bert_frozen``: a second copy of that model with ``requires_grad=False``.
    * ``trans_module``: ``AutoEncoder`` over the pooled output; when
      ``config.has_trans`` is set, its decoded output replaces the pooled one.
    * ``gated_combination``: fuses the frozen pooled output with the decoded
      one when both ``has_trans`` and ``has_gate`` are set.
    * ``mf_classifier``: ``FFClassifier`` over
      ``[embedding | one-hot domain | emotion features]``.
    * ``domain_classifier``: constructed but not used by ``forward``.
    """

    config_class = ME2BertConfig
    base_model_prefix = "me2bert"

    # Keys of each per-sample dict returned by ``forward(..., return_dict=True)``,
    # in the same order as the classifier's output columns.
    _MFT_DIMENSIONS = ("CH", "FC", "LB", "AS", "PD")

    def __init__(self, config: ME2BertConfig = None):
        """Build all sub-modules.

        Args:
            config: model configuration; ``ME2BertConfig()`` is used when ``None``.

        Note: both encoders are created with ``AutoModel.from_pretrained``
        here, so construction loads weights for ``config.pretrained_model_name``.
        """
        if config is None:
            config = ME2BertConfig()
        super().__init__(config)
        self.n_mf_classes = 5
        self.n_domain_classes = 2
        pretrained_model_name = config.pretrained_model_name
        self.has_gate = config.has_gate
        self.has_trans = config.has_trans
        # Default emotion input: an all-zero 5-entry row for every sample.
        # May be replaced by a list (repeated per sample) or a tensor of class
        # indices (one-hot encoded); see ``_emotion_features``.
        self.emotion_labels = [0, 0, 0, 0, 0]
        self.feature = AutoModel.from_pretrained(pretrained_model_name)
        self.bert_frozen = AutoModel.from_pretrained(pretrained_model_name)
        for param in self.bert_frozen.parameters():
            param.requires_grad = False
        self.embedding_dim = self.feature.config.hidden_size
        latent_dim = 128
        self.emotion_dim = 5
        self.gated_combination = GatedCombination(embedding_dim=self.embedding_dim)
        # Decoder maps back to embedding_dim, so the classifier input size
        # does not depend on has_trans.
        self.trans_module = AutoEncoder(self.embedding_dim, 256, latent_dim)
        initial_dim = self.embedding_dim + self.n_domain_classes + self.emotion_dim
        self.mf_classifier = FFClassifier(
            initial_dim, latent_dim, self.n_mf_classes, 0.0
        )
        self.domain_classifier = FFClassifier(
            self.embedding_dim, latent_dim, self.n_domain_classes,
        )

    def gen_feature_embeddings(self, input_ids, attention_mask):
        """Run the trainable encoder.

        Returns:
            ``(last_hidden_state, pooler_output)`` from ``self.feature``.
        """
        feature = self.feature(input_ids=input_ids, attention_mask=attention_mask)
        return feature.last_hidden_state, feature.pooler_output

    def _frozen_pooler_output(self, input_ids, attention_mask):
        """Pooled output of the frozen encoder, computed under ``torch.no_grad()``."""
        with torch.no_grad():
            frozen = self.bert_frozen(input_ids=input_ids, attention_mask=attention_mask)
        return frozen.pooler_output

    def _emotion_features(self, batch_size, device, dtype):
        """Build the emotion block of the classifier input.

        ``self.emotion_labels`` may be:

        * ``None``: an all-zero ``(batch_size, emotion_dim)`` block;
        * a list: used as one row and repeated ``batch_size`` times;
        * a tensor: treated as class indices and one-hot encoded to
          ``emotion_dim`` columns (a size-1 dimension 1 is squeezed away).

        The result is truncated to ``batch_size`` rows and placed on
        ``device`` with ``dtype``, so it can be concatenated with the embedding.
        """
        labels = self.emotion_labels
        if labels is None:
            return torch.zeros(batch_size, self.emotion_dim, device=device, dtype=dtype)
        if isinstance(labels, list):
            # Create on the embedding's device; a CPU tensor breaks torch.cat on GPU.
            features = torch.tensor(labels, dtype=dtype, device=device).repeat(batch_size, 1)
        else:
            features = torch.nn.functional.one_hot(
                labels.long().to(device), num_classes=self.emotion_dim
            ).squeeze(1)
        # one_hot yields int64; cast so the concatenated input matches the embedding dtype.
        return features[:batch_size, :].to(dtype)

    def forward(self, input_ids, attention_mask, return_dict=False):
        """Score a batch of token sequences.

        Args:
            input_ids: token ids passed to the encoder(s).
            attention_mask: attention mask passed to the encoder(s).
            return_dict: if true, return a list with one ``{dimension: score}``
                dict per sample (scores rounded to 5 decimals) instead of a tensor.

        Returns:
            ``torch.sigmoid`` of the classifier output: one row per sample and
            one column per entry of ``_MFT_DIMENSIONS``. If ``return_dict`` is
            true, returns the list of dicts described above instead.
        """
        _, pooler_output = self.gen_feature_embeddings(input_ids, attention_mask)

        if self.has_trans:
            # The autoencoder's reconstruction (decoded half) replaces the pooled output.
            _, decoded = self.trans_module(pooler_output)
            if self.has_gate:
                # The frozen encoder is only needed here; skip its forward otherwise.
                frozen_output = self._frozen_pooler_output(input_ids, attention_mask)
                gated_output = self.gated_combination(frozen_output, decoded)
            else:
                gated_output = decoded
        else:
            gated_output = pooler_output

        batch_size = gated_output.shape[0]
        device = gated_output.device
        dtype = gated_output.dtype

        # Every sample is assigned domain index 0, i.e. the one-hot row [1, 0].
        domain_labels = torch.zeros(batch_size, dtype=torch.long, device=device)
        domain_feature = torch.nn.functional.one_hot(
            domain_labels, num_classes=self.n_domain_classes
        ).to(dtype)
        emotion_features = self._emotion_features(batch_size, device, dtype)

        class_input = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
        class_output = torch.sigmoid(self.mf_classifier(class_input))

        if not return_dict:
            return class_output
        # Single device-to-host transfer instead of one .item() call per score.
        rows = class_output.detach().cpu().tolist()
        return [
            {name: round(score, 5) for name, score in zip(self._MFT_DIMENSIONS, row)}
            for row in rows
        ]