from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import AlbertForSequenceClassification, AlbertTokenizer
import torch
import torch.nn.functional as F
import numpy as np


class AlbertForMultilabelSequenceClassification(AlbertForSequenceClassification):
    """ALBERT classifier whose head is trained with a multi-label (BCE) objective."""

    def __init__(self, config):
        super().__init__(config)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids,
                              position_ids=position_ids,
                              head_mask=head_mask,
                              inputs_embeds=inputs_embeds,
                              output_attentions=output_attentions,
                              output_hidden_states=output_hidden_states,
                              return_dict=return_dict)

        # Classify from the pooled [CLS] representation.
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        # Multi-label loss: each label is scored independently with a sigmoid.
        loss = None
        if labels is not None:
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.float().view(-1, self.num_labels))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(loss=loss,
                                        logits=logits,
                                        hidden_states=outputs.hidden_states,
                                        attentions=outputs.attentions)
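

# A minimal sketch (not called anywhere in this module) of how the multi-label
# head above can be exercised end to end. It assumes the base `albert-base-v2`
# checkpoint is downloadable; the sentence and the multi-hot target below are
# made up purely for illustration.
def _multilabel_loss_example():
    tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
    model = AlbertForMultilabelSequenceClassification.from_pretrained(
        'albert-base-v2', num_labels=2)
    batch = tokenizer(["The contrast ratio of this button is too low."],
                      padding=True, truncation=True, return_tensors="pt")
    labels = torch.tensor([[1.0, 0.0]])  # multi-hot target, one entry per label
    output = model(**batch, labels=labels)
    # BCE-with-logits loss plus per-label probabilities via a sigmoid.
    return output.loss, torch.sigmoid(output.logits)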
class Model:
    def __init__(self):
        # Run on GPU when available, otherwise fall back to CPU.
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        self.labels = ['Accessibility', 'Non-accessibility']
        self.tokenizer = AlbertTokenizer.from_pretrained(
            'albert-base-v2', do_lower_case=True)
        classifier = AlbertForMultilabelSequenceClassification.from_pretrained(
            'albert-base-v2',
            output_attentions=False,
            output_hidden_states=False,
            num_labels=2
        )

        # Load the fine-tuned weights and switch to inference mode.
        classifier.load_state_dict(
            torch.load("assets/pytorch_model.bin", map_location=self.device))
        classifier = classifier.eval()
        self.classifier = classifier.to(self.device)

    def predict(self, text):
        encoded_text = self.tokenizer.encode_plus(
            text,
            max_length=30,
            add_special_tokens=True,
            return_token_type_ids=False,
            padding='longest',
            return_attention_mask=True,
            return_tensors="pt",
            truncation=True,
        )

        input_ids = encoded_text["input_ids"].to(self.device)
        attention_mask = encoded_text["attention_mask"].to(self.device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            output = self.classifier(input_ids, attention_mask)

        # Softmax over the two class logits yields one probability per label.
        probabilities = F.softmax(output.logits, dim=1).cpu().numpy().flatten()

        # Most likely class; the label string itself is not part of the return value.
        prediction_index = int(np.argmax(probabilities))
        label = self.labels[prediction_index]

        accessibility_prediction = probabilities[0]
        nonaccessibility_prediction = probabilities[1]

        return (accessibility_prediction, nonaccessibility_prediction)


# Module-level singleton so the checkpoint is loaded once per process.
model = Model()


def get_model():
    return model
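

# A minimal usage sketch, assuming the fine-tuned checkpoint exists at
# assets/pytorch_model.bin; the example sentence is made up for illustration.
if __name__ == "__main__":
    accessibility_score, non_accessibility_score = get_model().predict(
        "Add alt text to every image so screen readers can describe it.")
    print(f"Accessibility: {accessibility_score:.3f}, "
          f"Non-accessibility: {non_accessibility_score:.3f}")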