# NOTE: hosting-page residue kept as a comment -- uploader: DeepMount00;
# commit "Upload 12 files" (cc8997b, verified).
import random
from collections import defaultdict
from functools import partial
from typing import List, Tuple, Dict

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
class InstructBase(nn.Module):
    """Base module that enumerates candidate spans and collates them into batches.

    The ``config`` object must expose ``max_width`` (number of candidate
    spans started at each token), ``max_len`` (token truncation length),
    ``max_neg_type_ratio``, ``random_drop`` and ``max_types`` (the last
    three are only read when entity types are sampled in ``collate_fn``).
    """

    def __init__(self, config):
        super().__init__()
        self.max_width = config.max_width
        self.base_config = config

    def get_dict(self, spans, classes_to_id):
        """Map ``(start, end)`` to a class id for every span with a known label.

        Args:
            spans: Iterable of indexable items read as ``span[0]`` (start),
                ``span[1]`` (end) and ``span[2]`` (label).
            classes_to_id: Mapping from label to integer class id.

        Returns:
            A ``defaultdict(int)``, so spans that are absent read as 0 (the
            null label). Spans whose label is not in ``classes_to_id`` are
            skipped; if two spans share the same boundaries the last one wins.
        """
        dict_tag = defaultdict(int)
        for span in spans:
            if span[2] in classes_to_id:
                dict_tag[(span[0], span[1])] = classes_to_id[span[2]]
        return dict_tag

    def preprocess_spans(self, tokens, ner, classes_to_id):
        """Enumerate candidate spans for one example and attach their labels.

        Args:
            tokens: Token sequence; truncated to ``base_config.max_len``.
            ner: Annotated spans as accepted by ``get_dict`` (end index is
                inclusive); may be empty or ``None``.
            classes_to_id: Mapping from label to class id (ids start at 1 in
                ``collate_fn``; 0 is used here for "no entity").

        Returns:
            Dict with ``tokens`` (possibly truncated), ``span_idx`` (long
            tensor of shape ``(num_spans, 2)``), ``span_label`` (long tensor
            of shape ``(num_spans,)``; -1 marks spans that run past the last
            token), ``seq_length`` and ``entities`` (``ner`` unchanged).
        """
        max_len = self.base_config.max_len
        if len(tokens) > max_len:
            length = max_len
            tokens = tokens[:max_len]
        else:
            length = len(tokens)

        # Every start position gets ``max_width`` candidates: (i, i), (i, i + 1), ...
        spans = [(i, i + j) for i in range(length) for j in range(self.max_width)]

        label_by_span = self.get_dict(ner, classes_to_id) if ner else {}

        # 0 for null labels
        span_label = torch.tensor(
            [label_by_span.get(span, 0) for span in spans], dtype=torch.long
        )
        # The explicit reshape keeps the (num_spans, 2) layout even when no
        # span was generated (empty input), so the column indexing below and
        # the padding in ``collate_fn`` keep working.
        span_idx = torch.tensor(spans, dtype=torch.long).reshape(len(spans), 2)

        # Spans whose (inclusive) end falls beyond the last token are invalid.
        invalid_span_mask = span_idx[:, 1] > length - 1
        span_label = span_label.masked_fill(invalid_span_mask, -1)

        return {
            'tokens': tokens,
            'span_idx': span_idx,
            'span_label': span_label,
            'seq_length': length,
            'entities': ner,
        }

    def _sample_entity_types(self, example, negatives):
        """Choose the entity types used as label set for one training example.

        ``negatives`` is shuffled in place. The order of the ``random`` calls
        matches the original inline implementation.
        """
        random.shuffle(negatives)

        max_neg_type_ratio = int(self.base_config.max_neg_type_ratio)
        if max_neg_type_ratio == 0:
            # no negatives
            neg_type_ratio = 0
        else:
            neg_type_ratio = random.randint(0, max_neg_type_ratio)

        # A ratio of 0 yields an empty slice, i.e. no negatives.
        sampled_negatives = negatives[:len(example['ner']) * neg_type_ratio]

        # All candidate entity types (positive and negative), deduplicated.
        types = list(set([el[-1] for el in example['ner']] + sampled_negatives))

        # shuffle (every epoch)
        random.shuffle(types)

        # random drop: keep a random-length, non-empty prefix
        if types and self.base_config.random_drop:
            num_ents = random.randint(1, len(types))
            types = types[:num_ents]

        # maximum number of entity types
        types = types[:int(self.base_config.max_types)]

        # supervised training: an explicit label list overrides the sampling
        if "label" in example:
            types = sorted(example["label"])

        return types

    def collate_fn(self, batch_list, entity_types=None):
        """Turn a list of raw examples into one padded batch.

        Args:
            batch_list: List of dicts with keys ``"tokenized_text"`` and
                ``"ner"``; an optional ``"label"`` key fixes the label set of
                that example when ``entity_types`` is ``None``.
            entity_types: Optional sequence of labels shared by the whole
                batch. When ``None``, a label set is sampled per example from
                its own entity types plus negatives taken from the batch.

        Returns:
            Dict with ``seq_length``, ``span_idx`` (padded with 0),
            ``tokens``, ``span_mask`` (``span_label != -1``), ``span_label``
            (padded with -1), ``entities``, ``classes_to_id`` and
            ``id_to_classes``. The last two are single dicts when
            ``entity_types`` is given and per-example lists of dicts otherwise.
        """
        if entity_types is None:
            negs = self.get_negatives(batch_list, 100)
            class_to_ids = []
            id_to_classes = []
            for example in batch_list:
                types = self._sample_entity_types(example, negs)
                # Class ids start at 1; 0 is reserved for the null label.
                class_to_id = {label: idx for idx, label in enumerate(types, start=1)}
                class_to_ids.append(class_to_id)
                id_to_classes.append({idx: label for label, idx in class_to_id.items()})
            batch = [
                self.preprocess_spans(example["tokenized_text"], example["ner"], mapping)
                for example, mapping in zip(batch_list, class_to_ids)
            ]
        else:
            class_to_ids = {label: idx for idx, label in enumerate(entity_types, start=1)}
            id_to_classes = {idx: label for label, idx in class_to_ids.items()}
            batch = [
                self.preprocess_spans(example["tokenized_text"], example["ner"], class_to_ids)
                for example in batch_list
            ]

        span_idx = pad_sequence(
            [el['span_idx'] for el in batch], batch_first=True, padding_value=0
        )
        # Padding with -1 makes padded positions fall out of ``span_mask``.
        span_label = pad_sequence(
            [el['span_label'] for el in batch], batch_first=True, padding_value=-1
        )

        return {
            'seq_length': torch.LongTensor([el['seq_length'] for el in batch]),
            'span_idx': span_idx,
            'tokens': [el['tokens'] for el in batch],
            'span_mask': span_label != -1,
            'span_label': span_label,
            'entities': [el['entities'] for el in batch],
            'classes_to_id': class_to_ids,
            'id_to_classes': id_to_classes,
        }

    @staticmethod
    def get_negatives(batch_list, sampled_neg=5):
        """Return up to ``sampled_neg`` distinct entity types found in the batch.

        The types are the last element of every entry in each example's
        ``'ner'`` list; the result is in random order.
        """
        ent_types = list({el[-1] for b in batch_list for el in b['ner']})
        # sample negatives
        random.shuffle(ent_types)
        return ent_types[:sampled_neg]

    def create_dataloader(self, data, entity_types=None, **kwargs):
        """Build a ``DataLoader`` over ``data`` that batches with ``collate_fn``.

        ``entity_types`` is forwarded to ``collate_fn``; every other keyword
        argument goes to ``DataLoader`` unchanged.
        """
        # A partial over the bound method (unlike a lambda) can be pickled,
        # which worker processes need on platforms that spawn instead of fork.
        collate = partial(self.collate_fn, entity_types=entity_types)
        return DataLoader(data, collate_fn=collate, **kwargs)