Spaces:
Build error
Build error
#!g1.1 | |
import torch | |
import torch.nn as nn | |
from transformers import AutoTokenizer, AutoModel | |
import copy | |
class DL_category(nn.Module):
    """Category head: maps a 256-d sentence vector to 5 raw (unnormalized) scores.

    Architecture: Linear(256 -> 64) -> ReLU -> Linear(64 -> 5).
    Both weight matrices are re-initialized with Xavier-uniform; the biases
    keep PyTorch's default ``nn.Linear`` initialization.
    """

    def __init__(self):
        super().__init__()
        # Build and initialize each layer before the next one is created so the
        # random-number stream is consumed in a fixed order.
        hidden = nn.Linear(256, 64)
        nn.init.xavier_uniform_(hidden.weight)
        self.lin1 = hidden
        scores = nn.Linear(64, 5)
        nn.init.xavier_uniform_(scores.weight)
        self.lin2 = scores

    def forward(self, x):
        activated = torch.relu(self.lin1(x))
        return self.lin2(activated)
class DL_sentiment(nn.Module):
    """Sentiment head: maps a 256-d sentence vector to a single scalar score.

    Architecture: Linear(256 -> 64) -> ReLU -> Linear(64 -> 1, no bias).
    Both weight matrices are re-initialized with Xavier-uniform; the first
    layer's bias keeps PyTorch's default ``nn.Linear`` initialization.
    """

    def __init__(self):
        super().__init__()
        # Create-then-initialize each layer in turn so the random-number
        # stream is consumed in a fixed order.
        hidden = nn.Linear(256, 64)
        nn.init.xavier_uniform_(hidden.weight)
        self.lin1 = hidden
        score = nn.Linear(64, 1, bias=False)
        nn.init.xavier_uniform_(score.weight)
        self.lin2 = score

    def forward(self, x):
        activated = torch.relu(self.lin1(x))
        return self.lin2(activated)
def mean_pooling(model_output, attention_mask):
    """Average token vectors along dim 1, counting only unmasked positions.

    ``attention_mask`` is broadcast (via ``unsqueeze(-1)`` + ``expand``) to the
    shape of ``model_output``; positions where it is 0 contribute nothing.
    The per-row token count is clamped to at least 1e-9, so a fully masked row
    yields zeros instead of NaN.
    """
    mask = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
    summed = (model_output * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
class Union_model(nn.Module):
    """Frozen BERT encoder feeding a category head and a sentiment head.

    The constructor mutates the encoder it is given: it freezes all of the
    encoder's existing parameters, replaces ``pooler`` with a fresh (trainable)
    ``Linear(768 -> 256)``, and moves the encoder to CPU. It then builds the
    two heads.

    NOTE(review): the encoder reference is NOT stored on ``self``. ``forward``
    calls the module-level global ``bert_model`` instead. Consequences:

    - the encoder is not a registered submodule, so it is absent from this
      module's ``state_dict()``;
    - it is not copied by ``copy.deepcopy``;
    - it is not affected by ``.eval()`` / ``.to()`` on this module.

    The loading code at module level relies on this: it loads the encoder
    weights into the global after the model is built. Do not "fix" this in
    isolation.
    """

    def __init__(self, bert_model):
        super().__init__()
        # Freeze the encoder's existing weights. The pooler installed below is
        # created afterwards, so it stays trainable.
        for frozen in bert_model.parameters():
            frozen.requires_grad = False
        # Presumably applied by the encoder to its token-level hidden states,
        # which is why forward() mean-pools pooler_output with the attention
        # mask. Verify against the transformers version in use.
        bert_model.pooler = nn.Sequential(
            nn.Linear(in_features=768, out_features=256)
        )
        bert_model.to('cpu')
        self.DL_cat = DL_category()
        self.DL_sent = DL_sentiment()

    def forward(self, input):
        # `bert_model` here is the module-level global, not the constructor
        # argument (see class docstring).
        encoded = bert_model(**input)
        sentence_vec = mean_pooling(encoded.pooler_output, input['attention_mask'])
        return self.DL_cat(sentence_vec), self.DL_sent(sentence_vec)
class LogisticCumulativeLink(nn.Module):
    """
    Converts a single number to the proportional odds of belonging to a class.

    Parameters
    ----------
    num_classes : int
        Number of ordered classes to partition the odds into (must be >= 3).
    init_cutpoints : str (default='ordered')
        How to initialize the ``num_classes - 1`` cutpoints. Valid values are

        - ordered : evenly spaced one unit apart, starting at
          ``-(num_classes - 1) / 2`` (e.g. ``[-1., 0.]`` for 3 classes).
        - random : drawn uniformly from [0, 1) and sorted ascending.

    Raises
    ------
    AssertionError
        If ``num_classes`` < 3.
    ValueError
        If ``init_cutpoints`` is not one of the values above.
    """

    def __init__(self, num_classes: int,
                 init_cutpoints: str = 'ordered') -> None:
        assert num_classes > 2, (
            'Only use this model if you have 3 or more classes'
        )
        super().__init__()
        self.num_classes = num_classes
        self.init_cutpoints = init_cutpoints
        num_cutpoints = self.num_classes - 1
        if init_cutpoints == 'ordered':
            cutpoints = torch.arange(num_cutpoints).float() - num_cutpoints / 2
        elif init_cutpoints == 'random':
            cutpoints = torch.rand(num_cutpoints).sort()[0]
        else:
            raise ValueError(f'{init_cutpoints} is not a valid init_cutpoints '
                             f'type')
        self.cutpoints = nn.Parameter(cutpoints)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Map predictor scores to per-class probabilities.

        Equation (11) from
        "On the consistency of ordinal regression methods", Pedregosa et. al.

        Parameters
        ----------
        X : torch.Tensor
            Scores of shape ``(batch, 1)``. A 1-D ``(batch,)`` tensor or a
            0-D scalar is also accepted and treated as a column of scores.

        Returns
        -------
        torch.Tensor
            Shape ``(batch, num_classes)``. Each row sums to 1, because the
            differences of consecutive sigmoids telescope.
        """
        if X.dim() < 2:
            # Reshape to a column so it broadcasts against the cutpoints;
            # otherwise the 2-D slicing below raises IndexError.
            X = X.reshape(-1, 1)
        sigmoids = torch.sigmoid(self.cutpoints - X)  # (batch, num_classes - 1)
        # P(y=0) = s_0, P(y=k) = s_k - s_{k-1}, P(y=K-1) = 1 - s_{K-2}
        link_mat = torch.cat(
            (
                sigmoids[:, [0]],
                sigmoids[:, 1:] - sigmoids[:, :-1],
                1 - sigmoids[:, [-1]],
            ),
            dim=1,
        )
        return link_mat
class CustomOrdinalLogisticModel(nn.Module):
    """Ordinal-regression wrapper around a two-output predictor.

    The predictor must return a pair ``(cat, sent)``. ``cat`` is passed
    through unchanged. ``sent`` is turned into per-class probabilities by a
    ``LogisticCumulativeLink`` with ``num_classes`` ordered classes.

    Parameters
    ----------
    predictor : nn.Module
        Model returning ``(cat, sent)``. It is deep-copied, so this wrapper
        owns the copy rather than the object passed in.
    num_classes : int
        Number of ordered classes produced from ``sent``.
    init_cutpoints : str (default='ordered')
        Forwarded to ``LogisticCumulativeLink``.
    """

    def __init__(self, predictor: nn.Module, num_classes: int,
                 init_cutpoints: str = 'ordered') -> None:
        super().__init__()
        self.num_classes = num_classes
        self.predictor = copy.deepcopy(predictor)
        self.link = LogisticCumulativeLink(self.num_classes,
                                           init_cutpoints=init_cutpoints)

    def forward(self, *args, **kwargs) -> 'tuple[torch.Tensor, torch.Tensor]':
        """Return ``(cat, probs)``.

        ``probs`` has shape ``(batch, num_classes)``.
        """
        cat, sent = self.predictor(*args, **kwargs)
        return cat, self.link(sent)
tokenizer = AutoTokenizer.from_pretrained('blanchefort/rubert-base-cased-sentiment-rusentiment')
# NOTE(review): output_hidden_states=True makes the encoder also return every
# layer's hidden states. Union_model.forward only reads pooler_output.
bert_model = AutoModel.from_pretrained('blanchefort/rubert-base-cased-sentiment-rusentiment',
                                       output_hidden_states=True).to('cpu')
# NOTE(review): Union_model.__init__ replaces the pooler again with a fresh
# Linear(768 -> 256), so this assignment appears redundant. The pooler weights
# actually used come from 'best_model_bert.pth', which is loaded strictly below.
bert_model.pooler = nn.Sequential(
    nn.Linear(in_features=768, out_features=256)
)
model = CustomOrdinalLogisticModel(Union_model(bert_model), 3).to('cpu')
# strict=False: keys present in only one of (checkpoint, model) are ignored
# rather than raising.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
model.load_state_dict(torch.load('best_model_heads.pth', map_location='cpu'), strict=False)
# The encoder is not a submodule of `model` (Union_model.forward reads this
# global), so its weights are loaded separately, after `model` is built.
bert_model.load_state_dict(torch.load('best_model_bert.pth', map_location='cpu'))
# Put both modules in inference mode explicitly. `model.eval()` alone does not
# reach `bert_model`, because it is not registered as a submodule of `model`.
model.eval()
bert_model.eval()
def inference(input_data):
    """Run the category and ordinal-sentiment heads on ``input_data['sentence']``.

    Parameters
    ----------
    input_data : dict
        Must contain the key ``'sentence'``: the text to classify. A single
        string is the usual case; a list of strings is also accepted and is
        padded into one batch.

    Returns
    -------
    dict
        ``{'answer': (cat, probs)}``, with one row per input sentence. For the
        model built in this module, ``cat`` has 5 columns and ``probs`` has 3.
    """
    # Padding lets a list of sentences form one rectangular batch; for a single
    # string it is a no-op. Truncation keeps long inputs within the encoder's
    # position-embedding limit instead of crashing it.
    encoded = tokenizer(
        input_data['sentence'],
        padding=True,
        truncation=True,
        max_length=bert_model.config.max_position_embeddings,
        return_tensors='pt',
    )
    batch = {
        'input_ids': encoded['input_ids'].to('cpu'),
        'attention_mask': encoded['attention_mask'].to('cpu'),
    }
    # `bert_model` is not a submodule of `model`, so switch both to eval mode.
    model.eval()
    bert_model.eval()
    # Inference only: skip building the autograd graph. As a result the
    # returned tensors do not require grad.
    with torch.no_grad():
        answer = model(batch)
    return dict(answer=answer)