|
from transformers import PreTrainedModel |
|
from transformers import AutoModel |
|
import torch |
|
from torch.autograd import Function |
|
from configuration_me2bert import ME2BertConfig |
|
|
|
class ReverseLayerF(Function):
    """Autograd function that reverses (and scales) gradients.

    Forward: returns a view of the input with the same values.
    Backward: returns the incoming gradient negated and multiplied by the
    ``alpha`` value that was supplied to the forward call.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Remember ``alpha`` on the context and pass ``x`` through as a view."""
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        """Return the reversed gradient for ``x``; ``alpha`` receives no gradient."""
        reversed_grad = (-grad_output) * ctx.alpha
        # Second slot corresponds to ``alpha``, which is not differentiated.
        return reversed_grad, None
|
|
|
|
|
class FFClassifier(torch.nn.Module):
    """Feed-forward classification head.

    The stack is Linear -> BatchNorm1d -> ReLU (in place) -> Dropout -> Linear.
    The output of the final linear layer is returned as-is; no activation is
    applied inside this module.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        n_classes: Number of output units.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, input_dim, hidden_dim, n_classes, dropout=0.0):
        super().__init__()

        # Layer order is kept exactly as-is so that the ``model.<index>.*``
        # parameter names remain stable.
        stack = [
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(hidden_dim, n_classes),
        ]
        self.model = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.model(x)
|
|
|
|
|
class Encoder(torch.nn.Module):
    """Two-layer MLP: Linear -> PReLU -> Linear.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the intermediate layer.
        latent_dim: Size of the returned representation.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim, bias=True)
        self.fc2 = torch.nn.Linear(hidden_dim, latent_dim, bias=True)
        self.prelu = torch.nn.PReLU()

    def forward(self, x):
        """Project ``x`` to ``hidden_dim``, apply PReLU, then project to ``latent_dim``."""
        hidden = self.fc1(x)
        activated = self.prelu(hidden)
        return self.fc2(activated)
|
|
|
|
|
class Decoder(torch.nn.Module):
    """Two-layer MLP: Linear -> PReLU -> Linear.

    Args:
        latent_dim: Size of each input vector.
        hidden_dim: Width of the intermediate layer.
        output_dim: Size of the returned vector.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(latent_dim, hidden_dim, bias=True)
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim, bias=True)
        self.prelu = torch.nn.PReLU()

    def forward(self, x):
        """Project ``x`` to ``hidden_dim``, apply PReLU, then project to ``output_dim``."""
        hidden = self.prelu(self.fc1(x))
        reconstruction = self.fc2(hidden)
        return reconstruction
|
|
|
|
|
class AutoEncoder(torch.nn.Module):
    """Encoder -> LayerNorm -> Decoder bottleneck.

    Args:
        input_dim: Size of the input vectors (and of the decoder output).
        hidden_dim: Hidden width passed to both the encoder and the decoder.
        latent_dim: Size of the layer-normalised bottleneck.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.layer_norm = torch.nn.LayerNorm(latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def forward(self, x):
        """Return ``(encoded, decoded)``.

        ``encoded`` is the layer-normalised encoder output; ``decoded`` is the
        decoder applied to that normalised code.
        """
        latent = self.layer_norm(self.encoder(x))
        reconstruction = self.decoder(latent)
        return latent, reconstruction
|
|
|
|
|
class GatedCombination(torch.nn.Module):
    """Merge two same-sized embeddings through learned sigmoid gates.

    Computation performed by :meth:`forward`::

        f   = sigmoid(W_f * frozen)
        i   = sigmoid(W_i * finetuned)
        c   = f * frozen + i * finetuned
        out = sigmoid(W_o * c) * tanh(c)

    Args:
        embedding_dim: Feature size of both inputs and of the output.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

        # One square linear map per gate.
        self.forget_gate = torch.nn.Linear(embedding_dim, embedding_dim)
        self.input_gate = torch.nn.Linear(embedding_dim, embedding_dim)
        self.output_gate = torch.nn.Linear(embedding_dim, embedding_dim)

        self.sigmoid = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()

    def forward(self, frozen_output, finetuned_output):
        """Return the gated combination of ``frozen_output`` and ``finetuned_output``."""
        keep_frozen = self.sigmoid(self.forget_gate(frozen_output))
        admit_finetuned = self.sigmoid(self.input_gate(finetuned_output))

        blended = keep_frozen * frozen_output + admit_finetuned * finetuned_output

        emit = self.sigmoid(self.output_gate(blended))
        return emit * self.tanh(blended)
|
|
|
|
|
class ME2BertModel(PreTrainedModel):
    """Transformer encoder with auxiliary features and a 5-way sigmoid head.

    Components built in ``__init__``:

    * ``feature``: trainable encoder loaded with ``AutoModel.from_pretrained``.
    * ``bert_frozen``: second copy of the same checkpoint with every parameter
      set to ``requires_grad = False``.
    * ``trans_module``: :class:`AutoEncoder` over the pooled embedding
      (used only when ``config.has_trans`` is true).
    * ``gated_combination``: :class:`GatedCombination` merging the frozen and
      transformed embeddings (used only when both ``config.has_trans`` and
      ``config.has_gate`` are true).
    * ``mf_classifier``: :class:`FFClassifier` over the concatenation of the
      embedding, a one-hot domain vector and an emotion vector.
    * ``domain_classifier``: :class:`FFClassifier` over the embedding; it is
      constructed here but not called by :meth:`forward`.
    """

    config_class = ME2BertConfig
    base_model_prefix = "me2bert"

    # Keys of the per-sample dicts returned when ``return_dict=True``, in the
    # same order as the outputs of ``mf_classifier``.
    # NOTE(review): presumably abbreviations of the moral-foundation
    # dimensions -- confirm against the project documentation.
    MFT_DIMENSIONS = ('CH', 'FC', 'LB', 'AS', 'PD')

    def __init__(self, config: ME2BertConfig = None):
        """Build all sub-modules.

        Args:
            config: Model configuration; a default ``ME2BertConfig()`` is
                created when omitted. The fields read here are
                ``pretrained_model_name``, ``has_gate`` and ``has_trans``.
        """
        if config is None:
            config = ME2BertConfig()

        super().__init__(config)
        self.n_mf_classes = 5
        self.n_domain_classes = 2
        pretrained_model_name = config.pretrained_model_name
        self.has_gate = config.has_gate
        self.has_trans = config.has_trans
        # Default emotion vector appended to every sample in ``forward``.
        self.emotion_labels = [0, 0, 0, 0, 0]
        self.feature = AutoModel.from_pretrained(pretrained_model_name)
        self.bert_frozen = AutoModel.from_pretrained(pretrained_model_name)

        for param in self.bert_frozen.parameters():
            param.requires_grad = False

        self.embedding_dim = self.feature.config.hidden_size
        latent_dim = 128
        self.emotion_dim = 5

        self.gated_combination = GatedCombination(embedding_dim=self.embedding_dim)
        self.trans_module = AutoEncoder(self.embedding_dim, 256, latent_dim)

        # The classifier sees [embedding | one-hot domain | emotion vector].
        initial_dim = self.embedding_dim + self.n_domain_classes + self.emotion_dim

        self.mf_classifier = FFClassifier(
            initial_dim, latent_dim, self.n_mf_classes, .0
        )
        self.domain_classifier = FFClassifier(
            self.embedding_dim, latent_dim, self.n_domain_classes,
        )

    def gen_feature_embeddings(self, input_ids, attention_mask):
        """Run the trainable encoder.

        Returns:
            Tuple ``(last_hidden_state, pooler_output)`` taken from the
            output of ``self.feature``.
        """
        feature = self.feature(input_ids=input_ids, attention_mask=attention_mask)
        return feature.last_hidden_state, feature.pooler_output

    def _emotion_feature_matrix(self, batch_size, device, dtype):
        """Build the emotion feature rows on ``device`` with ``dtype``.

        * ``self.emotion_labels is None``: zeros of shape
          ``(batch_size, emotion_dim)``.
        * list: converted to a tensor and repeated ``batch_size`` times
          along dim 0.
        * otherwise (treated as a tensor of class indices): one-hot encoded
          with ``emotion_dim`` classes, then ``squeeze(1)``.

        The non-``None`` results are truncated to the first ``batch_size`` rows.
        """
        labels = self.emotion_labels
        if labels is None:
            return torch.zeros(batch_size, self.emotion_dim, device=device, dtype=dtype)

        if isinstance(labels, list):
            # Create the tensor directly on the target device: concatenating
            # a CPU tensor with tensors on another device raises.
            features = torch.tensor(labels, dtype=dtype, device=device).repeat(batch_size, 1)
        else:
            features = torch.nn.functional.one_hot(
                labels.long(), num_classes=self.emotion_dim
            ).squeeze(1)
            features = features.to(device=device, dtype=dtype)

        return features[:batch_size, :]

    def forward(self, input_ids, attention_mask, return_dict=False):
        """Score a batch.

        Args:
            input_ids: Token ids forwarded to the encoder(s).
            attention_mask: Attention mask forwarded to the encoder(s).
            return_dict: When true, return a list with one dict per sample
                mapping each key in ``MFT_DIMENSIONS`` to its score rounded
                to 5 decimals; otherwise return the score tensor.

        Returns:
            The sigmoid of the ``mf_classifier`` output (one row per sample,
            ``n_mf_classes`` columns), or the list-of-dicts form described
            above.
        """
        _, pooler_output = self.gen_feature_embeddings(input_ids, attention_mask)
        device = pooler_output.device

        gated_output = pooler_output
        if self.has_trans:
            # Only the reconstruction (second output) is used downstream.
            _, gated_output = self.trans_module(pooler_output)
            if self.has_gate:
                # The frozen encoder is only needed for gating, so its forward
                # pass is skipped on every other path.
                with torch.no_grad():
                    frozen_output = self.bert_frozen(
                        input_ids=input_ids, attention_mask=attention_mask
                    ).pooler_output
                gated_output = self.gated_combination(frozen_output, gated_output)

        batch_size = gated_output.shape[0]
        dtype = gated_output.dtype

        # Every sample is assigned domain index 0.
        domain_labels = torch.zeros(batch_size, dtype=torch.long, device=device)
        domain_feature = torch.nn.functional.one_hot(
            domain_labels, num_classes=self.n_domain_classes
        ).to(dtype)

        emotion_features = self._emotion_feature_matrix(batch_size, device, dtype)

        class_input = torch.cat((gated_output, domain_feature, emotion_features), dim=1)
        class_output = torch.sigmoid(self.mf_classifier(class_input))

        if not return_dict:
            return class_output

        # A single host transfer instead of one ``.item()`` call per score.
        return [
            dict(zip(self.MFT_DIMENSIONS, (round(score, 5) for score in row)))
            for row in class_output.detach().tolist()
        ]
|
|