import hashlib
import logging
import os.path
from collections import defaultdict

import numpy as np
import torch
from datasets.load import load_dataset, load_metric
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    default_data_collator,
)

task_to_keys = {
    "boolq": ("question", "passage"),
    "cb": ("premise", "hypothesis"),
    "rte": ("premise", "hypothesis"),
    "wic": ("processed_sentence1", None),
    "wsc": ("span2_word_text", "span1_text"),
    "copa": (None, None),
    "record": (None, None),
    "multirc": ("paragraph", "question_answer"),
}

logger = logging.getLogger(__name__)


class SuperGlueDataset:
    def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
        super().__init__()
        raw_datasets = load_dataset("super_glue", data_args.dataset_name)
        self.tokenizer = tokenizer
        self.data_args = data_args

        # COPA is treated as a multiple-choice task; ReCoRD is cast to binary
        # entity classification; the remaining tasks use their own label list.
        self.multiple_choice = data_args.dataset_name in ["copa"]

        if data_args.dataset_name == "record":
            self.num_labels = 2
            self.label_list = ["0", "1"]
        elif not self.multiple_choice:
            self.label_list = raw_datasets["train"].features["label"].names
            self.num_labels = len(self.label_list)
        else:
            self.num_labels = 1
        # Preprocessing the raw_datasets
        self.sentence1_key, self.sentence2_key = task_to_keys[data_args.dataset_name]

        # Padding strategy
        if data_args.pad_to_max_length:
            self.padding = "max_length"
        else:
            # Pad dynamically at batch creation, to the longest sequence in each batch
            self.padding = False

        if not self.multiple_choice:
            self.label2id = {label: i for i, label in enumerate(self.label_list)}
            self.id2label = {id: label for label, id in self.label2id.items()}
            print(f"{self.label2id}")
            print(f"{self.id2label}")

        if data_args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
        if data_args.dataset_name == "record":
            # ReCoRD preprocessing is expensive, so the tokenized dataset is cached
            # on disk under a key derived from the tokenizer name.
            digest = hashlib.md5(f"record_{tokenizer.name_or_path}".encode("utf-8")).hexdigest()[:16]  # first 16 hex chars
            path = raw_datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"record-{digest}.arrow")
            if not os.path.exists(path):
                print(f"-> path not found!: {path}")
                raw_datasets = raw_datasets.map(
                    self.record_preprocess_function,
                    batched=True,
                    load_from_cache_file=not data_args.overwrite_cache,
                    remove_columns=raw_datasets["train"].column_names,
                    desc="Running tokenizer on dataset",
                )
                data = {"raw_datasets": raw_datasets}
                torch.save(data, path)
            raw_datasets = torch.load(path)["raw_datasets"]
        else:
            raw_datasets = raw_datasets.map(
                self.preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
        if training_args.do_train:
            self.train_dataset = raw_datasets["train"]
            if data_args.max_train_samples is not None:
                self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))

        if training_args.do_eval:
            self.eval_dataset = raw_datasets["validation"]
            if data_args.max_eval_samples is not None:
                self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))

        if training_args.do_predict or data_args.dataset_name is not None or data_args.test_file is not None:
            self.predict_dataset = raw_datasets["test"]
            if data_args.max_predict_samples is not None:
                self.predict_dataset = self.predict_dataset.select(range(data_args.max_predict_samples))

        self.metric = load_metric("super_glue", data_args.dataset_name)

        if data_args.pad_to_max_length:
            self.data_collator = default_data_collator
        elif training_args.fp16:
            self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
        else:
            # Fall back to dynamic padding without the fp16 multiple-of-8 constraint
            self.data_collator = DataCollatorWithPadding(tokenizer)

        self.test_key = "accuracy" if data_args.dataset_name not in ["record", "multirc"] else "f1"

    def preprocess_function(self, examples):
        # WSC
        if self.data_args.dataset_name == "wsc":
            examples["span2_word_text"] = []
            for text, span2_index, span2_word in zip(examples["text"], examples["span2_index"], examples["span2_text"]):
                if self.data_args.template_id == 0:
                    examples["span2_word_text"].append(span2_word + ": " + text)
                elif self.data_args.template_id == 1:
                    words_a = text.split()
                    words_a[span2_index] = "*" + words_a[span2_index] + "*"
                    examples["span2_word_text"].append(" ".join(words_a))
        # WiC
        if self.data_args.dataset_name == "wic":
            examples["processed_sentence1"] = []
            if self.data_args.template_id == 1:
                self.sentence2_key = "processed_sentence2"
                examples["processed_sentence2"] = []
            for sentence1, sentence2, word, start1, end1, start2, end2 in zip(
                examples["sentence1"], examples["sentence2"], examples["word"],
                examples["start1"], examples["end1"], examples["start2"], examples["end2"]
            ):
                if self.data_args.template_id == 0:  # RoBERTa-style single-sequence template
                    examples["processed_sentence1"].append(
                        f"{sentence1} {sentence2} Does {word} have the same meaning in both sentences?"
                    )
                elif self.data_args.template_id == 1:  # BERT-style sentence-pair template
                    examples["processed_sentence1"].append(word + ": " + sentence1)
                    examples["processed_sentence2"].append(word + ": " + sentence2)
        # MultiRC
        if self.data_args.dataset_name == "multirc":
            examples["question_answer"] = []
            for question, answer in zip(examples["question"], examples["answer"]):
                examples["question_answer"].append(f"{question} {answer}")
        # COPA
        if self.data_args.dataset_name == "copa":
            examples["text_a"] = []
            for premise, question in zip(examples["premise"], examples["question"]):
                joiner = "because" if question == "cause" else "so"
                text_a = f"{premise} {joiner}"
                examples["text_a"].append(text_a)

            result1 = self.tokenizer(examples["text_a"], examples["choice1"], padding=self.padding,
                                     max_length=self.max_seq_length, truncation=True)
            result2 = self.tokenizer(examples["text_a"], examples["choice2"], padding=self.padding,
                                     max_length=self.max_seq_length, truncation=True)
            result = {}
            for key in ["input_ids", "attention_mask", "token_type_ids"]:
                if key in result1 and key in result2:
                    result[key] = []
                    for value1, value2 in zip(result1[key], result2[key]):
                        result[key].append([value1, value2])
            return result
        args = (
            (examples[self.sentence1_key],) if self.sentence2_key is None
            else (examples[self.sentence1_key], examples[self.sentence2_key])
        )
        result = self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True)
        return result

    def compute_metrics(self, p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.argmax(preds, axis=1)

        if self.data_args.dataset_name == "record":
            return self.record_compute_metrics(p)

        if self.data_args.dataset_name == "multirc":
            from sklearn.metrics import f1_score
            return {"f1": f1_score(p.label_ids, preds)}

        if self.data_args.dataset_name is not None:
            result = self.metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result
        elif self.is_regression:
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    def record_compute_metrics(self, p: EvalPrediction):
        from .utils import f1_score, exact_match_score, metric_max_over_ground_truths
        probs = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        examples = self.eval_dataset
        qid2pred = defaultdict(list)
        qid2ans = {}
        for prob, example in zip(probs, examples):
            qid = example["question_id"]
            qid2pred[qid].append((prob[1], example["entity"]))
            if qid not in qid2ans:
                qid2ans[qid] = example["answers"]

        n_correct, n_total = 0, 0
        f1, em = 0, 0
        for qid in qid2pred:
            # Pick the candidate entity with the highest positive-class probability.
            preds = sorted(qid2pred[qid], reverse=True)
            entity = preds[0][1]
            n_total += 1
            n_correct += entity in qid2ans[qid]
            f1 += metric_max_over_ground_truths(f1_score, entity, qid2ans[qid])
            em += metric_max_over_ground_truths(exact_match_score, entity, qid2ans[qid])
        acc = n_correct / n_total
        f1 = f1 / n_total
        em = em / n_total
        return {"f1": f1, "exact_match": em}

    def record_preprocess_function(self, examples, split="train"):
        results = {
            "index": list(),
            "question_id": list(),
            "input_ids": list(),
            "attention_mask": list(),
            # "token_type_ids": list(),
            "label": list(),
            "entity": list(),
            "answers": list(),
        }
        for idx, passage in enumerate(examples["passage"]):
            query, entities, answers = examples["query"][idx], examples["entities"][idx], examples["answers"][idx]
            index = examples["idx"][idx]
            passage = (
                passage.replace("@highlight\n", "- ")
                .replace(self.tokenizer.prompt_token, "")
                .replace(self.tokenizer.skey_token, "")
                .replace(self.tokenizer.predict_token, "")
            )

            # Each candidate entity becomes one binary example: is this entity a correct answer?
            for ent_idx, ent in enumerate(entities):
                question = (
                    query.replace("@placeholder", ent)
                    .replace(self.tokenizer.prompt_token, "")
                    .replace(self.tokenizer.skey_token, "")
                    .replace(self.tokenizer.predict_token, "")
                )
                result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length,
                                        truncation=True)
                label = 1 if ent in answers else 0

                results["input_ids"].append(result["input_ids"])
                results["attention_mask"].append(result["attention_mask"])
                # if "token_type_ids" in result: results["token_type_ids"].append(result["token_type_ids"])
                results["label"].append(label)
                results["index"].append(index)
                results["question_id"].append(index["query"])
                results["entity"].append(ent)
                results["answers"].append(answers)

        return results
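

# Minimal usage sketch (not part of the original module): it only assumes that
# `data_args` and `training_args` expose the attributes read in __init__ and
# preprocess_function (dataset_name, pad_to_max_length, max_seq_length,
# overwrite_cache, max_*_samples, test_file, template_id, do_train/do_eval/
# do_predict, fp16). SimpleNamespace stands in here for the real argument
# dataclasses that a training script would build with HfArgumentParser.
if __name__ == "__main__":
    from types import SimpleNamespace

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    data_args = SimpleNamespace(
        dataset_name="rte",
        pad_to_max_length=True,
        max_seq_length=128,
        overwrite_cache=False,
        max_train_samples=None,
        max_eval_samples=None,
        max_predict_samples=None,
        test_file=None,
        template_id=0,
    )
    training_args = SimpleNamespace(do_train=True, do_eval=True, do_predict=False, fp16=False)

    dataset = SuperGlueDataset(tokenizer, data_args, training_args)
    # Inspect the tokenized splits and the metric key used for model selection.
    print(len(dataset.train_dataset), len(dataset.eval_dataset), dataset.test_key)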