from collections import OrderedDict

import numpy as np
import torch
from torch.utils import data
from torch.utils.data import Dataset
from datasets.arrow_dataset import Dataset as HFDataset
from datasets.load import load_dataset, load_metric
from transformers import AutoConfig, AutoTokenizer, DataCollatorForTokenClassification


class SRLDataset(Dataset):
    """Semantic role labeling dataset: loads the SRL splits, tokenizes them, and
    aligns the word-level tags with subword tokens for token classification."""

    def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
        super().__init__()
        raw_datasets = load_dataset(f'tasks/srl/datasets/{data_args.dataset_name}.py')
        self.tokenizer = tokenizer

        if training_args.do_train:
            column_names = raw_datasets["train"].column_names
            features = raw_datasets["train"].features
        else:
            column_names = raw_datasets["validation"].column_names
            features = raw_datasets["validation"].features

        self.label_column_name = "tags"
        self.label_list = features[self.label_column_name].feature.names
        self.label_to_id = {l: i for i, l in enumerate(self.label_list)}
        self.num_labels = len(self.label_list)

        if training_args.do_train:
            train_dataset = raw_datasets['train']
            if data_args.max_train_samples is not None:
                train_dataset = train_dataset.select(range(data_args.max_train_samples))
            self.train_dataset = train_dataset.map(
                self.tokenize_and_align_labels,
                batched=True,
                load_from_cache_file=True,
                desc="Running tokenizer on train dataset",
            )

        if training_args.do_eval:
            eval_dataset = raw_datasets['validation']
            if data_args.max_eval_samples is not None:
                eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
            self.eval_dataset = eval_dataset.map(
                self.tokenize_and_align_labels,
                batched=True,
                load_from_cache_file=True,
                desc="Running tokenizer on validation dataset",
            )

        if training_args.do_predict:
            if data_args.dataset_name == "conll2005":
                # CoNLL-2005 ships two test sets (WSJ and Brown), so keep both.
                self.predict_dataset = OrderedDict()
                self.predict_dataset['wsj'] = raw_datasets['test_wsj'].map(
                    self.tokenize_and_align_labels,
                    batched=True,
                    load_from_cache_file=True,
                    desc="Running tokenizer on WSJ test dataset",
                )
                self.predict_dataset['brown'] = raw_datasets['test_brown'].map(
                    self.tokenize_and_align_labels,
                    batched=True,
                    load_from_cache_file=True,
                    desc="Running tokenizer on Brown test dataset",
                )
            else:
                self.predict_dataset = raw_datasets['test_wsj'].map(
                    self.tokenize_and_align_labels,
                    batched=True,
                    load_from_cache_file=True,
                    desc="Running tokenizer on WSJ test dataset",
                )

        self.data_collator = DataCollatorForTokenClassification(
            self.tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None
        )
        self.metric = load_metric("seqeval")

    def compute_metrics(self, p):
        predictions, labels = p
        predictions = np.argmax(predictions, axis=2)

        # Remove the ignored index (special tokens labeled -100).
        true_predictions = [
            [self.label_list[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        true_labels = [
            [self.label_list[l] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]

        results = self.metric.compute(predictions=true_predictions, references=true_labels)
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

    def tokenize_and_align_labels(self, examples):
        # Append "[SEP]" and the predicate token (at position `index`) to each
        # sentence so the model sees which verb the role labels refer to.
        for i, tokens in enumerate(examples['tokens']):
            examples['tokens'][i] = tokens + ["[SEP]"] + [tokens[int(examples['index'][i])]]

        tokenized_inputs = self.tokenizer(
            examples['tokens'],
            padding=False,
            truncation=True,
            # We use this argument because the texts in our dataset are lists of
            # words (with a label for each word).
            is_split_into_words=True,
        )

        labels = []
        for i, label in enumerate(examples['tags']):
            # Rebuild the word ids by hand instead of calling
            # tokenized_inputs.word_ids(batch_index=i), which is only available
            # for fast tokenizers. The sequence layout is:
            # [CLS] sentence subwords ... [SEP] predicate subwords [SEP]
            word_ids = [None]  # [CLS]
            for j, word in enumerate(examples['tokens'][i][:-2]):
                sub_tokens = self.tokenizer.encode(word, add_special_tokens=False)
                word_ids += [j] * len(sub_tokens)
            word_ids += [None]  # the appended "[SEP]" marker
            verb = examples['tokens'][i][int(examples['index'][i])]
            word_ids += [None] * len(self.tokenizer.encode(verb, add_special_tokens=False))
            word_ids += [None]  # final [SEP]

            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                # Special tokens have a word id of None. We set their label to -100
                # so they are automatically ignored in the loss function.
                if word_idx is None:
                    label_ids.append(-100)
                # We set the label for the first token of each word.
                elif word_idx != previous_word_idx:
                    label_ids.append(label[word_idx])
                # The remaining subword tokens of a word also get -100, so only the
                # first subword carries the tag.
                else:
                    label_ids.append(-100)
                previous_word_idx = word_idx
            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs
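

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): one way this dataset could be
# wired into a Hugging Face Trainer. The `SRLDataArguments` dataclass, the
# output directory, and the "bert-base-uncased" checkpoint are illustrative
# assumptions, and the sketch presumes the dataset loading script
# tasks/srl/datasets/<dataset_name>.py exists in the working directory.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass
    from typing import Optional

    from transformers import AutoModelForTokenClassification, Trainer, TrainingArguments

    @dataclass
    class SRLDataArguments:
        # Hypothetical container mirroring the fields SRLDataset reads.
        dataset_name: str = "conll2005"
        max_train_samples: Optional[int] = None
        max_eval_samples: Optional[int] = None

    training_args = TrainingArguments(output_dir="srl_output", do_train=True, do_eval=True)
    data_args = SRLDataArguments()

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    dataset = SRLDataset(tokenizer, data_args, training_args)

    model = AutoModelForTokenClassification.from_pretrained(
        "bert-base-uncased", num_labels=dataset.num_labels
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset.train_dataset,
        eval_dataset=dataset.eval_dataset,
        tokenizer=tokenizer,
        data_collator=dataset.data_collator,
        compute_metrics=dataset.compute_metrics,
    )
    trainer.train()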