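"""SuperGLUE ReCoRD data utilities: a choice-aware data collator and a dataset
wrapper that turns each (passage, query, candidate entity) triple into a
multiple-choice classification example."""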
import torch
from torch.utils import data
from torch.utils.data import Dataset
from datasets.arrow_dataset import Dataset as HFDataset
from datasets.load import load_dataset, load_metric
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    default_data_collator,
    DataCollatorForLanguageModeling
)
import random
import numpy as np
import logging
from .dataset import SuperGlueDataset
from dataclasses import dataclass
from transformers.data.data_collator import DataCollatorMixin
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

logger = logging.getLogger(__name__)
@dataclass
class DataCollatorForMultipleChoice(DataCollatorMixin):
    # Collator for ReCoRD-style multiple choice: each feature carries one
    # tokenized (passage, question) pair per candidate entity.
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"
    def torch_call(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

        # Each sample may carry a different number of candidate entities, so the
        # choice dimension is padded first by repeating the first candidate.
        input_list = [feature["input_ids"] for feature in features]
        choice_nums = list(map(len, input_list))
        max_choice_num = max(choice_nums)

        def pad_choice_dim(data, choice_num):
            if len(data) < choice_num:
                data = list(data) + [data[0]] * (choice_num - len(data))
            return data

        # Flatten to [batch_size * max_choice_num] single-sequence features so the
        # tokenizer can pad the sequence dimension in one call.
        flat_features = []
        for feature in features:
            for key in feature:
                if key != label_name:
                    feature[key] = pad_choice_dim(feature[key], max_choice_num)
            flat_features.extend(
                {key: value[i] for key, value in feature.items() if key != label_name}
                for i in range(max_choice_num)
            )

        batch = self.tokenizer.pad(
            flat_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Restore the [batch_size, max_choice_num, seq_len] layout.
        batch = {k: v.view(len(features), max_choice_num, -1) for k, v in batch.items()}

        if labels is not None:
            # Labels are scalar indices of the gold candidate, so they need no padding.
            batch[label_name] = torch.tensor(labels, dtype=torch.int64)
        return batch
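# Illustrative only (not part of the original module): a minimal sketch of how the
# collator is expected to behave, assuming a BERT-style tokenizer ("bert-base-uncased"
# is an arbitrary choice). It builds two samples with 2 and 3 candidate choices and
# prints the collated shapes, which should come out as [batch, max_choices, seq_len].
def _collator_smoke_test():
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    features = []
    for candidates in (["Paris", "London"], ["Paris", "London", "Berlin"]):
        enc = tokenizer(["The capital is @placeholder."] * len(candidates), candidates, truncation=True)
        features.append({
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
            "token_type_ids": enc["token_type_ids"],
            "label": 0,
        })
    batch = DataCollatorForMultipleChoice(tokenizer)(features)
    # Expected: input_ids (2, 3, seq_len), attention_mask (2, 3, seq_len), label (2,)
    print({k: tuple(v.shape) for k, v in batch.items()})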
class SuperGlueDatasetForRecord(SuperGlueDataset):
    def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
        raw_datasets = load_dataset("super_glue", data_args.dataset_name)
        self.tokenizer = tokenizer
        self.data_args = data_args

        # labels
        self.multiple_choice = data_args.dataset_name in ["copa", "record"]
        if not self.multiple_choice:
            self.label_list = raw_datasets["train"].features["label"].names
            self.num_labels = len(self.label_list)
        else:
            self.num_labels = 1

        # Padding strategy
        if data_args.pad_to_max_length:
            self.padding = "max_length"
        else:
            # We will pad later, dynamically at batch creation, to the max sequence length in each batch
            self.padding = False

        # Some models have set the order of the labels to use, so let's make sure we do use it.
        self.label_to_id = None
        if self.label_to_id is not None:
            self.label2id = self.label_to_id
            self.id2label = {id: label for label, id in self.label2id.items()}
        elif not self.multiple_choice:
            self.label2id = {l: i for i, l in enumerate(self.label_list)}
            self.id2label = {id: label for label, id in self.label2id.items()}
        if data_args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
        if training_args.do_train:
            self.train_dataset = raw_datasets["train"]
            if data_args.max_train_samples is not None:
                self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))

            self.train_dataset = self.train_dataset.map(
                self.prepare_train_dataset,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                remove_columns=raw_datasets["train"].column_names,
                desc="Running tokenizer on train dataset",
            )

        if training_args.do_eval:
            self.eval_dataset = raw_datasets["validation"]
            if data_args.max_eval_samples is not None:
                self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))

            self.eval_dataset = self.eval_dataset.map(
                self.prepare_eval_dataset,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                remove_columns=raw_datasets["validation"].column_names,
                desc="Running tokenizer on validation dataset",
            )

        self.metric = load_metric("super_glue", data_args.dataset_name)

        self.data_collator = DataCollatorForMultipleChoice(tokenizer)
        # if data_args.pad_to_max_length:
        #     self.data_collator = default_data_collator
        # elif training_args.fp16:
        #     self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    def preprocess_function(self, examples):
        # Tokenize one (passage, query-with-candidate) pair per candidate entity and
        # label each candidate 1 if it is a gold answer, else 0.
        results = {
            "input_ids": list(),
            "attention_mask": list(),
            "token_type_ids": list(),
            "label": list()
        }
        for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]):
            passage = passage.replace("@highlight\n", "- ")
            input_ids = []
            attention_mask = []
            token_type_ids = []
            labels = []
            for _, ent in enumerate(entities):
                question = query.replace("@placeholder", ent)
                result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True)
                input_ids.append(result["input_ids"])
                attention_mask.append(result["attention_mask"])
                if "token_type_ids" in result:
                    token_type_ids.append(result["token_type_ids"])
                labels.append(1 if ent in answers else 0)
            results["input_ids"].append(input_ids)
            results["attention_mask"].append(attention_mask)
            if len(token_type_ids) > 0:
                results["token_type_ids"].append(token_type_ids)
            results["label"].append(labels)
        return results
    def prepare_train_dataset(self, examples, max_train_candidates_per_question=10):
        entity_shuffler = random.Random(44)
        results = {
            "input_ids": list(),
            "attention_mask": list(),
            "token_type_ids": list(),
            "label": list()
        }
        for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]):
            passage = passage.replace("@highlight\n", "- ")

            for answer in answers:
                input_ids = []
                attention_mask = []
                token_type_ids = []
                # Sample distractor entities, keeping at most
                # max_train_candidates_per_question - 1 of them.
                candidates = [ent for ent in entities if ent not in answers]
                # if len(candidates) < max_train_candidates_per_question - 1:
                #     continue
                if len(candidates) > max_train_candidates_per_question - 1:
                    entity_shuffler.shuffle(candidates)
                    candidates = candidates[:max_train_candidates_per_question - 1]
                # The gold answer is always the first candidate, so the label is 0.
                candidates = [answer] + candidates

                for ent in candidates:
                    question = query.replace("@placeholder", ent)
                    result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True)
                    input_ids.append(result["input_ids"])
                    attention_mask.append(result["attention_mask"])
                    if "token_type_ids" in result:
                        token_type_ids.append(result["token_type_ids"])

                results["input_ids"].append(input_ids)
                results["attention_mask"].append(attention_mask)
                if len(token_type_ids) > 0:
                    results["token_type_ids"].append(token_type_ids)
                results["label"].append(0)

        return results
    def prepare_eval_dataset(self, examples):
        results = {
            "input_ids": list(),
            "attention_mask": list(),
            "token_type_ids": list(),
            "label": list()
        }
        for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]):
            passage = passage.replace("@highlight\n", "- ")

            for answer in answers:
                input_ids = []
                attention_mask = []
                token_type_ids = []

                for ent in entities:
                    question = query.replace("@placeholder", ent)
                    result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True)
                    input_ids.append(result["input_ids"])
                    attention_mask.append(result["attention_mask"])
                    if "token_type_ids" in result:
                        token_type_ids.append(result["token_type_ids"])

                results["input_ids"].append(input_ids)
                results["attention_mask"].append(attention_mask)
                if len(token_type_ids) > 0:
                    results["token_type_ids"].append(token_type_ids)
                results["label"].append(0)

        return results
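# Illustrative only (not part of the original module): a rough sketch of how this
# dataset wrapper might be wired up, using types.SimpleNamespace as a stand-in for
# the HfArgumentParser dataclasses normally passed in as data_args/training_args.
# The field names below mirror the attributes accessed in __init__; the model name
# and sample counts are arbitrary assumptions. Because of the relative import of
# SuperGlueDataset, this needs to be run as a module (python -m <package>.<module>).
if __name__ == "__main__":
    from types import SimpleNamespace

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    data_args = SimpleNamespace(
        dataset_name="record",
        pad_to_max_length=False,
        max_seq_length=128,
        max_train_samples=8,
        max_eval_samples=8,
        overwrite_cache=False,
    )
    training_args = SimpleNamespace(do_train=True, do_eval=False)

    dataset = SuperGlueDatasetForRecord(tokenizer, data_args, training_args)
    loader = data.DataLoader(dataset.train_dataset, batch_size=2, collate_fn=dataset.data_collator)
    # Expected keys: input_ids, attention_mask, token_type_ids, label
    print({k: tuple(v.shape) for k, v in next(iter(loader)).items()})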