import numpy as np
import torch
import torch.nn.functional as F
from transformers import AlbertForSequenceClassification, AlbertTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput

class AlbertForMultilabelSequenceClassification(AlbertForSequenceClassification):
    """ALBERT sequence classifier whose loss is BCEWithLogitsLoss instead of
    CrossEntropyLoss, so each label is scored independently (multilabel)."""

    def __init__(self, config):
        super().__init__(config)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the shared ALBERT encoder.
        outputs = self.albert(input_ids,
                              attention_mask=attention_mask,
                              token_type_ids=token_type_ids,
                              position_ids=position_ids,
                              head_mask=head_mask,
                              inputs_embeds=inputs_embeds,
                              output_attentions=output_attentions,
                              output_hidden_states=output_hidden_states,
                              return_dict=return_dict)

        # Classify from the pooled [CLS] representation.
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Multilabel objective: one sigmoid + binary cross-entropy per label.
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.float().view(-1, self.num_labels))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(loss=loss,
                                        logits=logits,
                                        hidden_states=outputs.hidden_states,
                                        attentions=outputs.attentions)
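

# Illustrative sketch (not part of the original repository code): shows how the
# multilabel head above can be exercised with a multi-hot label tensor and the
# BCEWithLogitsLoss objective. The tiny AlbertConfig dimensions and the helper
# name _demo_multilabel_loss are assumptions chosen so the sketch runs offline
# with randomly initialized weights; the real model is loaded from
# 'albert-base-v2' plus the fine-tuned checkpoint further below.
def _demo_multilabel_loss():
    from transformers import AlbertConfig
    config = AlbertConfig(hidden_size=64,
                          num_hidden_layers=1,
                          num_attention_heads=4,
                          intermediate_size=128,
                          num_labels=2)
    demo_model = AlbertForMultilabelSequenceClassification(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    attention_mask = torch.ones_like(input_ids)
    # Multi-hot labels: each example may belong to any subset of the 2 classes.
    labels = torch.tensor([[1.0, 0.0], [1.0, 1.0]])
    output = demo_model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels)
    return output.loss, output.logits
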
class Model:
    """Loads the fine-tuned ALBERT classifier and tokenizer for inference."""

    def __init__(self):
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.labels = ['Accessibility', 'Non-accessibility']
        self.tokenizer = AlbertTokenizer.from_pretrained(
            'albert-base-v2', do_lower_case=True)

        # Build the base architecture, then overwrite its weights with the
        # fine-tuned checkpoint shipped with this repository.
        classifier = AlbertForMultilabelSequenceClassification.from_pretrained(
            'albert-base-v2',
            output_attentions=False,
            output_hidden_states=False,
            num_labels=2
        )
        classifier.load_state_dict(
            torch.load("assets/pytorch_model.bin", map_location=self.device))
        classifier = classifier.eval()
        self.classifier = classifier.to(self.device)
    def predict(self, text):
        """Returns (accessibility_probability, non_accessibility_probability)
        for a single piece of text."""
        encoded_text = self.tokenizer.encode_plus(
            text,
            max_length=30,
            add_special_tokens=True,
            return_token_type_ids=False,
            padding='longest',
            return_attention_mask=True,
            return_tensors="pt",
            truncation=True,
        )
        input_ids = encoded_text["input_ids"].to(self.device)
        attention_mask = encoded_text["attention_mask"].to(self.device)

        with torch.no_grad():
            outputs = self.classifier(input_ids, attention_mask)

        # Compute the class probabilities once and reuse them.
        all_predictions = F.softmax(outputs.logits, dim=1).cpu().numpy().flatten()
        prediction_index = int(np.argmax(all_predictions))
        label = self.labels[prediction_index]  # top label; currently unused

        accessibility_prediction = all_predictions[0]
        nonaccessibility_prediction = all_predictions[1]
        return (accessibility_prediction, nonaccessibility_prediction)
# Module-level singleton so the model is loaded only once per process.
model = Model()
# model.predict("this is an improvement")


def get_model():
    return model
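

# Illustrative usage sketch (an assumption, not part of the original file):
# running the module directly prints the two probabilities returned by
# Model.predict. It requires the fine-tuned weights at assets/pytorch_model.bin,
# which are loaded at import time above; the sample sentence is made up.
if __name__ == "__main__":
    accessibility, non_accessibility = get_model().predict(
        "The screen reader cannot focus the submit button.")
    print(f"Accessibility: {accessibility:.3f}, "
          f"Non-accessibility: {non_accessibility:.3f}")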