|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import time |
|
from dataclasses import dataclass, field |
|
from enum import Enum |
|
from typing import Dict, List, Optional, Union |
|
|
|
import torch |
|
from filelock import FileLock |
|
from torch.utils.data import Dataset |
|
|
|
from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING |
|
from ...tokenization_utils import PreTrainedTokenizer |
|
from ...utils import logging |
|
from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features |
|
|
|
|
|
# Module-level logger obtained from the library's logging helper.
logger = logging.get_logger(__name__)

# Config classes registered in the question-answering auto-model mapping.
MODEL_CONFIG_CLASSES = [*MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()]
# `model_type` of each config class above; used to build the `model_type` help text.
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)
|
|
|
|
|
@dataclass
class SquadDataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Each field carries a `help` entry in its metadata describing the option. Within this module,
    `SquadDataset` reads `model_type`, `data_dir`, `max_seq_length`, `doc_stride`, `max_query_length`,
    `overwrite_cache`, `version_2_with_negative`, `lang_id` and `threads`; the remaining fields are not
    consumed here (presumably they are used by the evaluation/post-processing code -- confirm with callers).
    """

    # NOTE: `model_type` and `data_dir` are annotated `str` but default to `None`.
    model_type: str = field(
        default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}
    )
    data_dir: str = field(
        default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    doc_stride: int = field(
        default=128,
        metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
    )
    max_query_length: int = field(
        default=64,
        metadata={
            "help": (
                "The maximum number of tokens for the question. Questions longer than this will "
                "be truncated to this length."
            )
        },
    )
    max_answer_length: int = field(
        default=30,
        metadata={
            "help": (
                "The maximum length of an answer that can be generated. This is needed because the start "
                "and end predictions are not conditioned on one another."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    version_2_with_negative: bool = field(
        default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
    )
    null_score_diff_threshold: float = field(
        default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
    )
    # The help text here used to be a copy of the `null_score_diff_threshold` help, which made
    # the generated `--help` output describe the wrong option.
    n_best_size: int = field(
        default=20, metadata={"help": "The total number of n-best predictions to generate."}
    )
    lang_id: int = field(
        default=0,
        metadata={
            "help": (
                "language id of input for language-specific xlm models (see"
                " tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
            )
        },
    )
    threads: int = field(default=1, metadata={"help": "multiple threads for converting example to features"})
|
|
|
|
|
class Split(Enum):
    """
    Dataset split selector; each member's value is its own name as a string.
    """

    train = "train"
    dev = "dev"
|
|
|
|
|
class SquadDataset(Dataset):
    """
    Map-style dataset of SQuAD features with an on-disk feature cache.

    On construction the features are either loaded from a cache file or built from the raw examples
    found in `args.data_dir` and then written to that cache file. Indexing returns a dict of tensors
    built from a single feature.

    This will be superseded by a framework-agnostic approach soon.
    """

    args: SquadDataTrainingArguments
    features: List[SquadFeatures]
    mode: Split
    is_language_sensitive: bool

    # Model types for which `token_type_ids` is left out of the returned inputs.
    _NO_TOKEN_TYPE_IDS_MODEL_TYPES = ("xlm", "roberta", "distilbert", "camembert")
    # Model types that additionally receive `cls_index` / `p_mask` (and the optional extras).
    _CLS_INDEX_MODEL_TYPES = ("xlnet", "xlm")

    def __init__(
        self,
        args: SquadDataTrainingArguments,
        tokenizer: PreTrainedTokenizer,
        limit_length: Optional[int] = None,
        mode: Union[str, Split] = Split.train,
        is_language_sensitive: Optional[bool] = False,
        cache_dir: Optional[str] = None,
        dataset_format: Optional[str] = "pt",
    ):
        """
        Load the features from the cache file, or build them and write the cache file.

        Args:
            args: Data arguments (data directory, sequence lengths, cache policy, ...).
            tokenizer: Tokenizer handed to `squad_convert_examples_to_features`; its class name is
                part of the cache file name.
            limit_length: Accepted for interface compatibility; currently unused by this class.
            mode: A `Split` member or the name of one (`"train"` / `"dev"`).
            is_language_sensitive: When truthy, `__getitem__` adds a `langs` tensor for the
                `xlnet` / `xlm` model types.
            cache_dir: Directory of the cache file; falls back to `args.data_dir` when `None`.
            dataset_format: Forwarded as `return_dataset` to `squad_convert_examples_to_features`.

        Raises:
            KeyError: If `mode` is a string that is not the name of a `Split` member.
        """
        self.args = args
        self.is_language_sensitive = is_language_sensitive
        self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
        if isinstance(mode, str):
            try:
                mode = Split[mode]
            except KeyError as err:
                raise KeyError("mode is not a valid split name") from err
        self.mode = mode

        # The cache file name encodes the split, tokenizer class, max sequence length and SQuAD version.
        version_tag = "v2" if args.version_2_with_negative else "v1"
        cached_features_file = os.path.join(
            cache_dir if cache_dir is not None else args.data_dir,
            f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{version_tag}",
        )

        # The existence check and the build/save both run while holding the file lock, so
        # constructors sharing this lock file do not build the same cache at the same time.
        lock_path = cached_features_file + ".lock"
        with FileLock(lock_path):
            if os.path.exists(cached_features_file) and not args.overwrite_cache:
                self._load_cached_features(cached_features_file)
            else:
                self._build_and_cache_features(tokenizer, cached_features_file, dataset_format)

    def _load_cached_features(self, cached_features_file: str) -> None:
        """Populate `features`, `dataset` and `examples` from an existing cache file."""
        start = time.time()
        # NOTE(security): `torch.load` unpickles Python objects; only point this at cache files
        # that come from a trusted location.
        self.old_features = torch.load(cached_features_file)

        # Cache files may lack the "dataset" / "examples" entries; those then stay `None`.
        self.features = self.old_features["features"]
        self.dataset = self.old_features.get("dataset", None)
        self.examples = self.old_features.get("examples", None)
        # Pass the path as a lazy %-argument: interpolating it into the format string first
        # breaks the log call whenever the path itself contains a '%'.
        logger.info(
            "Loading features from cached file %s [took %.3f s]", cached_features_file, time.time() - start
        )

        if self.dataset is None or self.examples is None:
            logger.warning(
                "Deleting cached file %s will allow dataset and examples to be cached in future run",
                cached_features_file,
            )

    def _build_and_cache_features(
        self, tokenizer: PreTrainedTokenizer, cached_features_file: str, dataset_format: Optional[str]
    ) -> None:
        """Read the raw examples, convert them to features and write the cache file."""
        if self.mode == Split.dev:
            self.examples = self.processor.get_dev_examples(self.args.data_dir)
        else:
            self.examples = self.processor.get_train_examples(self.args.data_dir)

        self.features, self.dataset = squad_convert_examples_to_features(
            examples=self.examples,
            tokenizer=tokenizer,
            max_seq_length=self.args.max_seq_length,
            doc_stride=self.args.doc_stride,
            max_query_length=self.args.max_query_length,
            is_training=self.mode == Split.train,
            threads=self.args.threads,
            return_dataset=dataset_format,
        )

        start = time.time()
        torch.save(
            {"features": self.features, "dataset": self.dataset, "examples": self.examples},
            cached_features_file,
        )
        logger.info(
            "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
        )

    def __len__(self):
        """Return the number of features."""
        return len(self.features)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """
        Build the model inputs for the feature at index `i`.

        Always contains `input_ids` and `attention_mask`. `token_type_ids` is included unless the
        model type is one of `xlm`, `roberta`, `distilbert`, `camembert`. For `xlnet` / `xlm`,
        `cls_index` and `p_mask` are added, plus `is_impossible` when `version_2_with_negative` is
        set and `langs` when the dataset is language sensitive. In train mode `start_positions`
        and `end_positions` are added as well.
        """
        feature = self.features[i]
        model_type = self.args.model_type

        input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": torch.tensor(feature.attention_mask, dtype=torch.long),
        }

        # Only build the tensors that are actually returned.
        if model_type not in self._NO_TOKEN_TYPE_IDS_MODEL_TYPES:
            inputs["token_type_ids"] = torch.tensor(feature.token_type_ids, dtype=torch.long)

        if model_type in self._CLS_INDEX_MODEL_TYPES:
            inputs["cls_index"] = torch.tensor(feature.cls_index, dtype=torch.long)
            inputs["p_mask"] = torch.tensor(feature.p_mask, dtype=torch.float)
            if self.args.version_2_with_negative:
                inputs["is_impossible"] = torch.tensor(feature.is_impossible, dtype=torch.float)
            if self.is_language_sensitive:
                # Same shape as `input_ids`, filled with the configured language id.
                inputs["langs"] = torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id

        if self.mode == Split.train:
            inputs["start_positions"] = torch.tensor(feature.start_position, dtype=torch.long)
            inputs["end_positions"] = torch.tensor(feature.end_position, dtype=torch.long)

        return inputs
|
|