|
|
|
import torch |
|
|
|
class DecodeAndEvaluate:
    """Decode grid-tagging model outputs into (aspect, opinion, sentiment)
    triplets and score predictions against gold labels.

    Tag-grid convention (inferred from usage in this class — confirm against
    the model that produces ``tags``):
      * the diagonal cell ``tags[l][l]`` of a word's first sub-token marks its
        type (1 = aspect, 2 = opinion);
      * cells in the rectangle crossing an aspect span with an opinion span
        carry the sentiment id (3/4/5); negative cells are masked/ignored.
    """

    def __init__(self, tokenizer):
        """Store the tokenizer and build the sentiment id <-> label maps."""
        self.tokenizer = tokenizer
        # Sentiment polarity ids as they appear in the tag grid.
        self.sentiment2id = {'negative': 3, 'neutral': 4, 'positive': 5}
        self.id2sentiment = {v: k for k, v in self.sentiment2id.items()}

    def get_span_from_tags(self, tags, token_range, tok_type):
        """Collect maximal runs of words whose diagonal tag equals ``tok_type``.

        Args:
            tags: 2-D tag grid, indexable as ``tags[i][j]``.
            token_range: per-word ``[first_subtoken, last_subtoken]`` pairs.
            tok_type: diagonal tag to match (1 = aspect, 2 = opinion).

        Returns:
            List of ``[start, end]`` sub-token index pairs (inclusive).
        """
        sel_spans = []
        start_ind, end_ind = -1, -1
        has_prev = False
        for l, r in token_range:
            # Only the word's first sub-token (the diagonal) carries the type.
            if tags[l][l] != tok_type:
                # Run is broken: flush any span currently open.
                if has_prev:
                    sel_spans.append([start_ind, end_ind])
                start_ind, end_ind = -1, -1
                has_prev = False
            elif not has_prev:
                # Open a new span at this word.
                start_ind, end_ind = l, r
                has_prev = True
            else:
                # Extend the currently open span through this word.
                end_ind = r
        # Flush a span that runs to the end of the sentence.
        if has_prev:
            sel_spans.append([start_ind, end_ind])
        return sel_spans

    def _region_sentiment(self, tags, row_lo, row_hi, col_lo, col_hi):
        """Majority tag value in ``tags[row_lo:row_hi+1, col_lo:col_hi+1]``,
        ignoring masked (negative) cells; returns 0 when all cells are masked.
        """
        flat = tags[row_lo:row_hi + 1, col_lo:col_hi + 1].reshape(-1)
        flat = flat[flat >= 0]  # vectorized mask instead of a Python loop
        if flat.numel() == 0:
            # Fully-masked region: torch.mode would raise on an empty tensor.
            return 0
        return torch.mode(flat).values.item()

    def find_triplet(self, tags, aspect_spans, opinion_spans):
        """Pair each aspect span with each opinion span; keep pairs whose
        crossing region has a positive majority (sentiment) tag.

        Returns:
            List of ``[asp_l, asp_r, op_l, op_r, sentiment_id]``.
        """
        triplets = []
        for al, ar in aspect_spans:
            for pl, pr in opinion_spans:
                # The grid is used upper-triangular: index with the earlier
                # span on the row axis, the later one on the column axis.
                if al <= pl:
                    val = self._region_sentiment(tags, al, ar, pl, pr)
                else:
                    val = self._region_sentiment(tags, pl, pr, al, ar)
                if val > 0:
                    triplets.append([al, ar, pl, pr, val])
        return triplets

    def decode_triplets(self, triplets, sent_tokens):
        """Turn index triplets into ``[aspect_text, opinion_text, polarity]``
        lists, dropping any triplet whose id is not a known sentiment."""
        triplet_list = []
        for alt, art, olt, ort, pol in triplets:
            if pol not in self.id2sentiment:
                continue
            asp_string = self.tokenizer.decode(sent_tokens[alt:art + 1])
            op_string = self.tokenizer.decode(sent_tokens[olt:ort + 1])
            triplet_list.append([asp_string, op_string, self.id2sentiment[pol]])
        return triplet_list

    def decode_predict_one(self, tags, token_range, sent_tokens):
        """Full decode for one sentence: spans -> index triplets -> strings."""
        aspect_spans = self.get_span_from_tags(tags, token_range, 1)
        opinion_spans = self.get_span_from_tags(tags, token_range, 2)
        triplets = self.find_triplet(tags, aspect_spans, opinion_spans)
        return self.decode_triplets(triplets, sent_tokens)

    def decode_pred_batch(self, tags_batch, token_range_batch, sent_tokens):
        """Decode every item of a batch; returns one triplet list per item."""
        return [
            self.decode_predict_one(tags, token_range, tokens)
            for tags, token_range, tokens
            in zip(tags_batch, token_range_batch, sent_tokens)
        ]

    def decode_predict_string_one(self, text_sent, model, max_len=64):
        """Tokenize a raw sentence, run ``model`` and decode its triplets.

        Args:
            text_sent: whitespace-separated input sentence.
            model: callable returning a dict with a ``'logits'`` tensor.
            max_len: padded sub-token length; longer inputs raise.

        Raises:
            Exception: when the sub-word sequence is longer than ``max_len``.
        """
        words = text_sent.strip().split()
        bert_tokens = self.tokenizer.encode(text_sent)
        tok_length = len(bert_tokens)
        if tok_length > max_len:
            # Report the limit that was actually exceeded (the original
            # message hard-coded a bogus "(>8,614)").
            raise Exception(f'Sub word length exceeded `max_len` (>{max_len})')

        # Map each whitespace word to its [first, last] sub-token positions,
        # offset by 1 for the leading special token (e.g. [CLS]).
        token_range = []
        token_start = 1
        for w in words:
            token_end = token_start + len(
                self.tokenizer.encode(w, add_special_tokens=False))
            token_range.append([token_start, token_end - 1])
            token_start = token_end

        bert_tokens_padding = torch.zeros(max_len).long()
        bert_tokens_padding[:tok_length] = torch.tensor(bert_tokens).long()
        attention_mask = torch.zeros(max_len).long()
        attention_mask[:tok_length] = 1

        tags_pred = model(bert_tokens_padding.unsqueeze(0),
                          attention_masks=attention_mask.unsqueeze(0))
        tags = tags_pred['logits'][0].argmax(dim=-1)
        return self.decode_predict_one(tags, token_range, bert_tokens)

    def get_batch_tp_fp_tn(self, tags_batch, token_range_batch, sent_tokens, gold_labels):
        """Count triplet-level matches over a batch.

        NOTE(review): despite the historical ``tn`` in the name, the third
        value returned is the FALSE-NEGATIVE count (gold triplets missed),
        as needed for precision/recall/F1 — the name is kept for callers.

        Returns:
            ``(tp, fp, fn)`` computed over unique ``"asp-op-pol"`` strings.
        """
        batch_results = self.decode_pred_batch(tags_batch, token_range_batch, sent_tokens)
        pred_set = {"-".join(pred) for preds in batch_results for pred in preds}
        gold_set = {"-".join(gold) for golds in gold_labels for gold in golds}
        tp = len(gold_set & pred_set)
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)
        return tp, fp, fn
|
|