# homeway's picture
# Add application file
# 7713b1f
import os.path
from datasets.load import load_dataset, load_metric
from transformers import (
AutoTokenizer,
DataCollatorWithPadding,
EvalPrediction,
default_data_collator,
)
import hashlib, torch
import numpy as np
import logging
from collections import defaultdict
# Dataset columns passed to the tokenizer for each SuperGLUE task, as a
# (first segment, second segment) pair. A ``None`` second entry means the
# input is a single segment. COPA and ReCoRD map to ``(None, None)`` because
# their inputs are assembled by dedicated preprocessing code elsewhere in
# this module rather than read from a fixed column.
task_to_keys = dict(
    boolq=("question", "passage"),
    cb=("premise", "hypothesis"),
    rte=("premise", "hypothesis"),
    wic=("processed_sentence1", None),
    wsc=("span2_word_text", "span1_text"),
    copa=(None, None),
    record=(None, None),
    multirc=("paragraph", "question_answer"),
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class SuperGlueDataset():
    """Load, tokenize and split one SuperGLUE task for training and evaluation.

    Construction loads the ``super_glue`` configuration named by
    ``data_args.dataset_name`` and tokenizes every split. Depending on
    ``training_args`` it then exposes ``train_dataset``, ``eval_dataset`` and
    ``predict_dataset``, together with the task ``metric``, a
    ``data_collator`` and ``test_key`` (the metric name reported as the main
    score).

    ReCoRD is handled specially. Each (passage, query, candidate entity)
    triple becomes one binary example, and the tokenized result is cached on
    disk with ``torch.save``.
    """

    def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
        """Build the tokenized splits for ``data_args.dataset_name``.

        Args:
            tokenizer: Tokenizer applied to every example. For ReCoRD it must
                also expose ``prompt_token``, ``skey_token`` and
                ``predict_token``; those strings are removed from the text.
            data_args: Data options. This method reads ``dataset_name``,
                ``pad_to_max_length``, ``max_seq_length``,
                ``overwrite_cache``, ``max_train_samples``,
                ``max_eval_samples``, ``max_predict_samples`` and
                ``test_file``. ``preprocess_function`` also reads
                ``template_id``.
            training_args: Training options. This method reads ``do_train``,
                ``do_eval``, ``do_predict`` and ``fp16``.
        """
        super().__init__()
        raw_datasets = load_dataset("super_glue", data_args.dataset_name)
        self.tokenizer = tokenizer
        self.data_args = data_args
        # No SuperGLUE task is a regression task; compute_metrics branches on this flag.
        self.is_regression = False
        self.multiple_choice = data_args.dataset_name in ["copa"]
        if data_args.dataset_name == "record":
            # ReCoRD is recast as binary classification per candidate entity.
            self.num_labels = 2
            self.label_list = ["0", "1"]
        elif not self.multiple_choice:
            self.label_list = raw_datasets["train"].features["label"].names
            self.num_labels = len(self.label_list)
        else:
            # COPA: each example carries both choices (see preprocess_function).
            self.num_labels = 1

        # Columns fed to the tokenizer for this task.
        self.sentence1_key, self.sentence2_key = task_to_keys[data_args.dataset_name]

        # Padding strategy
        if data_args.pad_to_max_length:
            self.padding = "max_length"
        else:
            # We will pad later, dynamically at batch creation, to the max sequence length in each batch
            self.padding = False

        if not self.multiple_choice:
            self.label2id = {label: idx for idx, label in enumerate(self.label_list)}
            self.id2label = {idx: label for label, idx in self.label2id.items()}
            print(f"{self.label2id}")
            print(f"{self.id2label}")

        if data_args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

        if data_args.dataset_name == "record":
            # The cache key depends on the tokenizer: the first 16 hex chars (64 bits) of an MD5 digest.
            digest = hashlib.md5(f"record_{tokenizer.name_or_path}".encode("utf-8")).hexdigest()[:16]
            # Place the cache file next to the train split's own cache files.
            # NOTE(review): _get_cache_file_path is a private `datasets` API; the
            # "cache-.arrow" file name it yields may differ across versions.
            path = raw_datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"record-{digest}.arrow")
            # Rebuild when the cache is missing OR the user asked to overwrite caches.
            if data_args.overwrite_cache or not os.path.exists(path):
                print(f"-> path not found!:{path}")
                raw_datasets = raw_datasets.map(
                    self.record_preprocess_function,
                    batched=True,
                    load_from_cache_file=not data_args.overwrite_cache,
                    remove_columns=raw_datasets["train"].column_names,
                    desc="Running tokenizer on dataset",
                )
                data = {"raw_datasets": raw_datasets}
                torch.save(data, path)
            # NOTE(review): torch.load unpickles this file, so only trusted local
            # caches should live at `path`. Newer torch releases default to
            # weights_only=True, which may reject a pickled DatasetDict; verify
            # against the installed torch version.
            raw_datasets = torch.load(path)["raw_datasets"]
        else:
            raw_datasets = raw_datasets.map(
                self.preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )

        if training_args.do_train:
            self.train_dataset = raw_datasets["train"]
            if data_args.max_train_samples is not None:
                self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))
        if training_args.do_eval:
            self.eval_dataset = raw_datasets["validation"]
            if data_args.max_eval_samples is not None:
                self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))
        # NOTE(review): dataset_name is always set here (it was passed to
        # load_dataset above), so the test split is always prepared.
        if training_args.do_predict or data_args.dataset_name is not None or data_args.test_file is not None:
            self.predict_dataset = raw_datasets["test"]
            if data_args.max_predict_samples is not None:
                self.predict_dataset = self.predict_dataset.select(range(data_args.max_predict_samples))

        self.metric = load_metric("super_glue", data_args.dataset_name)

        if data_args.pad_to_max_length:
            # Examples are already padded to max_seq_length.
            self.data_collator = default_data_collator
        elif training_args.fp16:
            self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
        else:
            # Dynamic padding (self.padding is False): pad each batch when it is collated.
            self.data_collator = DataCollatorWithPadding(tokenizer)

        self.test_key = "accuracy" if data_args.dataset_name not in ["record", "multirc"] else "f1"

    def preprocess_function(self, examples):
        """Tokenize a batch of non-ReCoRD SuperGLUE examples.

        Intended for ``Dataset.map(batched=True)``: ``examples`` maps column
        names to equal-length lists. Task-specific text columns are built
        first (WSC, WiC, MultiRC). Then the columns named by
        ``self.sentence1_key`` and ``self.sentence2_key`` are tokenized.

        COPA is tokenized once per choice and returned early. Each of
        ``input_ids``, ``attention_mask`` and ``token_type_ids`` (when
        present) holds ``[choice1_value, choice2_value]`` per example.
        """
        # WSC
        if self.data_args.dataset_name == "wsc":
            examples["span2_word_text"] = []
            for text, span2_index, span2_word in zip(examples["text"], examples["span2_index"], examples["span2_text"]):
                if self.data_args.template_id == 0:
                    examples["span2_word_text"].append(span2_word + ": " + text)
                elif self.data_args.template_id == 1:
                    # Mark the pronoun in place by wrapping it in asterisks.
                    words_a = text.split()
                    words_a[span2_index] = "*" + words_a[span2_index] + "*"
                    examples["span2_word_text"].append(" ".join(words_a))
                # NOTE(review): any other template_id appends nothing, leaving this column shorter than the batch.

        # WiC
        if self.data_args.dataset_name == "wic":
            examples["processed_sentence1"] = []
            if self.data_args.template_id == 1:
                # This template uses a second segment; switch the key used for tokenization below.
                self.sentence2_key = "processed_sentence2"
                examples["processed_sentence2"] = []
            for sentence1, sentence2, word, start1, end1, start2, end2 in zip(examples["sentence1"],
                                                                              examples["sentence2"], examples["word"],
                                                                              examples["start1"], examples["end1"],
                                                                              examples["start2"], examples["end2"]):
                if self.data_args.template_id == 0:  # ROBERTA
                    examples["processed_sentence1"].append(
                        f"{sentence1} {sentence2} Does {word} have the same meaning in both sentences?")
                elif self.data_args.template_id == 1:  # BERT
                    examples["processed_sentence1"].append(word + ": " + sentence1)
                    examples["processed_sentence2"].append(word + ": " + sentence2)

        # MultiRC
        if self.data_args.dataset_name == "multirc":
            examples["question_answer"] = [
                f"{question} {answer}" for question, answer in zip(examples["question"], examples["answer"])
            ]

        # COPA
        if self.data_args.dataset_name == "copa":
            examples["text_a"] = []
            for premise, question in zip(examples["premise"], examples["question"]):
                joiner = "because" if question == "cause" else "so"
                examples["text_a"].append(f"{premise} {joiner}")
            result1 = self.tokenizer(examples["text_a"], examples["choice1"], padding=self.padding,
                                     max_length=self.max_seq_length, truncation=True)
            result2 = self.tokenizer(examples["text_a"], examples["choice2"], padding=self.padding,
                                     max_length=self.max_seq_length, truncation=True)
            result = {}
            for key in ["input_ids", "attention_mask", "token_type_ids"]:
                if key in result1 and key in result2:
                    # Pair the two choices per example: [[choice1, choice2], ...].
                    result[key] = [[value1, value2] for value1, value2 in zip(result1[key], result2[key])]
            return result

        args = (
            (examples[self.sentence1_key],) if self.sentence2_key is None else (
                examples[self.sentence1_key], examples[self.sentence2_key])
        )
        result = self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True)
        return result

    def compute_metrics(self, p: EvalPrediction):
        """Compute evaluation metrics for the current task.

        Args:
            p: Predictions and ``label_ids``. When ``p.predictions`` is a
                tuple, its first element is used as the logits.

        Returns:
            A metric dict. ReCoRD returns ``f1`` and ``exact_match``; MultiRC
            returns ``f1``. Other tasks return the ``super_glue`` metric
            output, plus ``combined_score`` (the mean of its values) when it
            has more than one entry.
        """
        if self.data_args.dataset_name == "record":
            # ReCoRD scores raw probabilities against eval examples; argmax labels are not used.
            return self.reocrd_compute_metrics(p)

        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if self.is_regression else np.argmax(preds, axis=1)

        if self.data_args.dataset_name == "multirc":
            from sklearn.metrics import f1_score
            # sklearn's signature is f1_score(y_true, y_pred).
            return {"f1": f1_score(p.label_ids, preds)}

        if self.data_args.dataset_name is not None:
            result = self.metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result
        elif self.is_regression:
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    def reocrd_compute_metrics(self, p: EvalPrediction):
        """Score ReCoRD predictions per query.

        Each row of ``self.eval_dataset`` is one (query, candidate entity)
        pair. Predictions are assumed to be in the same row order. For every
        ``question_id``, the candidate with the highest ``prob[1]`` is
        selected. It is compared with that query's gold ``answers`` using the
        ``f1_score`` and ``exact_match_score`` helpers from ``.utils``, taking
        the best score over all gold answers.

        Returns:
            ``{"f1": ..., "exact_match": ...}`` averaged over queries. Both are
            ``0.0`` when there is nothing to score.
        """
        from .utils import f1_score, exact_match_score, metric_max_over_ground_truths
        probs = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions

        qid2pred = defaultdict(list)
        qid2ans = {}
        for prob, example in zip(probs, self.eval_dataset):
            qid = example["question_id"]
            qid2pred[qid].append((prob[1], example["entity"]))
            qid2ans.setdefault(qid, example["answers"])

        n_total = len(qid2pred)
        if n_total == 0:
            return {"f1": 0.0, "exact_match": 0.0}

        f1, em = 0.0, 0.0
        for qid, candidates in qid2pred.items():
            # Highest probability wins. Ties fall back to tuple ordering (the larger entity string).
            _, entity = max(candidates)
            answers = qid2ans[qid]
            f1 += metric_max_over_ground_truths(f1_score, entity, answers)
            em += metric_max_over_ground_truths(exact_match_score, entity, answers)
        return {"f1": f1 / n_total, "exact_match": em / n_total}

    def record_preprocess_function(self, examples, split="train"):
        """Expand a batch of ReCoRD examples into one row per candidate entity.

        For every passage, ``"@highlight\\n"`` markers become ``"- "``. For
        every candidate entity, the query's ``"@placeholder"`` is filled with
        that entity. The tokenizer's ``prompt_token``, ``skey_token`` and
        ``predict_token`` strings are removed from both texts before
        tokenizing them as a (passage, question) pair. ``label`` is 1 when
        the entity is among the gold ``answers``, otherwise 0.

        ``split`` is accepted for interface compatibility and is not used.
        ``token_type_ids`` are not collected.
        """
        special_tokens = (self.tokenizer.prompt_token, self.tokenizer.skey_token, self.tokenizer.predict_token)

        def _strip_special(text):
            # Remove each special-token string, in the same order as before.
            for token in special_tokens:
                text = text.replace(token, "")
            return text

        results = {
            "index": [],
            "question_id": [],
            "input_ids": [],
            "attention_mask": [],
            "label": [],
            "entity": [],
            "answers": [],
        }
        for passage, query, entities, answers, index in zip(
            examples["passage"], examples["query"], examples["entities"], examples["answers"], examples["idx"]
        ):
            passage = _strip_special(passage.replace("@highlight\n", "- "))
            for ent in entities:
                question = _strip_special(query.replace("@placeholder", ent))
                encoded = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length,
                                         truncation=True)
                results["input_ids"].append(encoded["input_ids"])
                results["attention_mask"].append(encoded["attention_mask"])
                results["label"].append(1 if ent in answers else 0)
                results["index"].append(index)
                results["question_id"].append(index["query"])
                results["entity"].append(ent)
                results["answers"].append(answers)
        return results