import torch


class DecodeAndEvaluate:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.sentiment2id = {'negative': 3, 'neutral': 4, 'positive': 5}
        self.id2sentiment = {v: k for k, v in self.sentiment2id.items()}
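
    # Tag ids assumed from the decoding logic below (GTS-style grid
    # tagging): on the diagonal, 1 marks aspect sub-words and 2 marks
    # opinion sub-words; off-diagonal cells carry the sentiment ids
    # above (3, 4, 5), 0 means "no relation", and -1 masks cells that
    # should be ignored.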
    def get_span_from_tags(self, tags, token_range, tok_type):
        # tok_type: 1 selects aspect spans, 2 selects opinion spans.
        # A word belongs to a span when the diagonal tag of its first
        # sub-word equals tok_type; consecutive matching words merge
        # into one span.
        sel_spans = []
        start_ind, end_ind = -1, -1
        has_prev = False
        for l, r in token_range:
            if tags[l][l] != tok_type:
                if has_prev:
                    sel_spans.append([start_ind, end_ind])
                start_ind, end_ind = -1, -1
                has_prev = False
            elif not has_prev:
                start_ind, end_ind = l, r
                has_prev = True
            else:
                end_ind = r
        if has_prev:
            sel_spans.append([start_ind, end_ind])
        return sel_spans
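
    # Worked example for get_span_from_tags (hypothetical values): with
    # token_range = [[1, 1], [2, 3]] and diagonal tags
    # tags[1][1] == tags[2][2] == 1, tok_type=1 yields [[1, 3]]: the two
    # adjacent aspect words merge into one span over sub-words 1..3.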

    # Corner cases: a single opinion span can express different sentiments
    # towards different aspects, and a single aspect can have several
    # sentiments expressed on it, so each (aspect, opinion) pair votes by
    # majority (mode) over the tags in its overlap region.
    def find_triplet(self, tags, aspect_spans, opinion_spans):
        triplets = []
        for al, ar in aspect_spans:
            for pl, pr in opinion_spans:
                # Only the upper triangle of the tag grid is annotated,
                # so index rows with the span that starts earlier and
                # columns with the one that starts later.
                if al <= pl:
                    sent_tags = tags[al:ar + 1, pl:pr + 1]
                else:
                    sent_tags = tags[pl:pr + 1, al:ar + 1]
                flat_tags = torch.tensor([v.item() for v in sent_tags.reshape([-1]) if v.item() >= 0])
                if flat_tags.numel() == 0:
                    continue  # every cell in the overlap is masked
                val = torch.mode(flat_tags).values.item()
                if val > 0:
                    triplets.append([al, ar, pl, pr, val])
        return triplets
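
    # Worked example for find_triplet (hypothetical values): for aspect
    # span [1, 2] and opinion span [4, 4], the block tags[1:3, 4:5] is
    # flattened; non-negative values [5, 5, 0] vote 5 (positive) as the
    # mode, producing the triplet [1, 2, 4, 4, 5].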

    def decode_triplets(self, triplets, sent_tokens):
        triplet_list = []
        for alt, art, olt, ort, pol in triplets:
            asp_string = self.tokenizer.decode(sent_tokens[alt:art + 1])
            op_string = self.tokenizer.decode(sent_tokens[olt:ort + 1])
            if pol in self.id2sentiment:  # drop inconsistent (non-sentiment) votes
                triplet_list.append([asp_string, op_string, self.id2sentiment[pol]])
        return triplet_list
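
    # decode_predict_one returns the decoded string triplets, e.g.
    # [['battery life', 'great', 'positive']] (illustrative values).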
    def decode_predict_one(self, tags, token_range, sent_tokens):
        aspect_spans = self.get_span_from_tags(tags, token_range, 1)
        opinion_spans = self.get_span_from_tags(tags, token_range, 2)
        triplets = self.find_triplet(tags, aspect_spans, opinion_spans)
        return self.decode_triplets(triplets, sent_tokens)

    def decode_pred_batch(self, tags_batch, token_range_batch, sent_tokens):
        decoded_batch_results = []
        for i in range(tags_batch.shape[0]):
            res = self.decode_predict_one(tags_batch[i], token_range_batch[i], sent_tokens[i])
            decoded_batch_results.append(res)
        return decoded_batch_results

    def decode_predict_string_one(self, text_sent, model, max_len=64):
        words = text_sent.strip().split()
        bert_tokens = self.tokenizer.encode(text_sent)  # sub-word ids incl. special tokens
        tok_length = len(bert_tokens)
        if tok_length > max_len:
            raise ValueError(f'Sub-word length {tok_length} exceeds `max_len` ({max_len})')
        # token_range maps each whitespace word to the (first, last) index
        # of its sub-word tokens; index 0 is the [CLS] token, so sub-words
        # start at position 1.
        token_range = []
        token_start = 1
        for w in words:
            token_end = token_start + len(self.tokenizer.encode(w, add_special_tokens=False))
            token_range.append([token_start, token_end - 1])
            token_start = token_end
        bert_tokens_padding = torch.zeros(max_len).long()
        bert_tokens_padding[:tok_length] = torch.tensor(bert_tokens).long()
        attention_mask = torch.zeros(max_len).long()
        attention_mask[:tok_length] = 1
        tags_pred = model(bert_tokens_padding.unsqueeze(0),
                          attention_masks=attention_mask.unsqueeze(0))
        tags = tags_pred['logits'][0].argmax(dim=-1)
        return self.decode_predict_one(tags, token_range, bert_tokens)
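
    # Assumed model contract for decode_predict_string_one above: the
    # model returns a dict-like output with a 'logits' entry of shape
    # [batch, seq_len, seq_len, num_tags] and takes the padding mask via
    # the `attention_masks` keyword; any checkpoint exposing this
    # grid-tagging interface should work.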

    def get_batch_tp_fp_fn(self, tags_batch, token_range_batch, sent_tokens, gold_labels):
        # Flatten predicted and gold triplets into "aspect-opinion-sentiment"
        # strings, then count true positives, false positives, and false
        # negatives by exact set matching over the batch.
        batch_results = self.decode_pred_batch(tags_batch, token_range_batch, sent_tokens)
        flat_gold, flat_pred = [], []
        for preds, golds in zip(batch_results, gold_labels):
            for pred in preds:
                flat_pred.append("-".join(pred))
            for gold in golds:
                flat_gold.append("-".join(gold))
        gold_set = set(flat_gold)
        pred_set = set(flat_pred)
        tp = len(gold_set & pred_set)
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)
        return tp, fp, fn
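

# A minimal usage sketch under stated assumptions: `precision_recall_f1`
# is a hypothetical helper (not part of the class above), and the
# checkpoint/model names below are placeholders for any model matching
# the interface used in decode_predict_string_one.
def precision_recall_f1(tp, fp, fn):
    # Standard micro metrics from the counts produced by get_batch_tp_fp_fn.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed backbone
    decoder = DecodeAndEvaluate(tokenizer)
    # model = ...  # load a compatible grid-tagging model here, then:
    # print(decoder.decode_predict_string_one("The battery life is great", model))
    print(precision_recall_f1(8, 2, 4))  # -> (0.8, 0.666..., 0.727...)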