"""Tweaked AllenNLP dataset reader.""" import logging import re from random import random from typing import Dict, List from allennlp.common.file_utils import cached_path from allennlp.data.dataset_readers.dataset_reader import DatasetReader from allennlp.data.fields import TextField, SequenceLabelField, MetadataField, Field from allennlp.data.instance import Instance from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer from allennlp.data.tokenizers import Token from overrides import overrides from utils.helpers import SEQ_DELIMETERS, START_TOKEN logger = logging.getLogger(__name__) # pylint: disable=invalid-name @DatasetReader.register("seq2labels_datareader") class Seq2LabelsDatasetReader(DatasetReader): """ Reads instances from a pretokenised file where each line is in the following format: WORD###TAG [TAB] WORD###TAG [TAB] ..... \n and converts it into a ``Dataset`` suitable for sequence tagging. You can also specify alternative delimiters in the constructor. Parameters ---------- delimiters: ``dict`` The dcitionary with all delimeters. token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``) We use this to define the input representation for the text. See :class:`TokenIndexer`. Note that the `output` tags will always correspond to single token IDs based on how they are pre-tokenised in the data file. max_len: if set than will truncate long sentences """ # fix broken sentences mostly in Lang8 BROKEN_SENTENCES_REGEXP = re.compile(r'\.[a-zA-RT-Z]') def __init__(self, token_indexers: Dict[str, TokenIndexer] = None, delimeters: dict = SEQ_DELIMETERS, skip_correct: bool = False, skip_complex: int = 0, lazy: bool = False, max_len: int = None, test_mode: bool = False, tag_strategy: str = "keep_one", tn_prob: float = 0, tp_prob: float = 0, broken_dot_strategy: str = "keep") -> None: super().__init__(lazy) self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} self._delimeters = delimeters self._max_len = max_len self._skip_correct = skip_correct self._skip_complex = skip_complex self._tag_strategy = tag_strategy self._broken_dot_strategy = broken_dot_strategy self._test_mode = test_mode self._tn_prob = tn_prob self._tp_prob = tp_prob @overrides def _read(self, file_path): # if `file_path` is a URL, redirect to the cache file_path = cached_path(file_path) with open(file_path, "r") as data_file: logger.info("Reading instances from lines in file at: %s", file_path) for line in data_file: line = line.strip("\n") # skip blank and broken lines if not line or (not self._test_mode and self._broken_dot_strategy == 'skip' and self.BROKEN_SENTENCES_REGEXP.search(line) is not None): continue tokens_and_tags = [pair.rsplit(self._delimeters['labels'], 1) for pair in line.split(self._delimeters['tokens'])] try: tokens = [Token(token) for token, tag in tokens_and_tags] tags = [tag for token, tag in tokens_and_tags] except ValueError: tokens = [Token(token[0]) for token in tokens_and_tags] tags = None if tokens and tokens[0] != Token(START_TOKEN): tokens = [Token(START_TOKEN)] + tokens words = [x.text for x in tokens] if self._max_len is not None: tokens = tokens[:self._max_len] tags = None if tags is None else tags[:self._max_len] instance = self.text_to_instance(tokens, tags, words) if instance: yield instance def extract_tags(self, tags: List[str]): op_del = self._delimeters['operations'] labels = [x.split(op_del) for x in tags] comlex_flag_dict = {} # get flags for i in range(5): idx = i + 1 comlex_flag_dict[idx] = 
sum([len(x) > idx for x in labels]) if self._tag_strategy == "keep_one": # get only first candidates for r_tags in right and the last for left labels = [x[0] for x in labels] elif self._tag_strategy == "merge_all": # consider phrases as a words pass else: raise Exception("Incorrect tag strategy") detect_tags = ["CORRECT" if label == "$KEEP" else "INCORRECT" for label in labels] return labels, detect_tags, comlex_flag_dict def text_to_instance(self, tokens: List[Token], tags: List[str] = None, words: List[str] = None) -> Instance: # type: ignore """ We take `pre-tokenized` input here, because we don't have a tokenizer in this class. """ # pylint: disable=arguments-differ fields: Dict[str, Field] = {} sequence = TextField(tokens, self._token_indexers) fields["tokens"] = sequence fields["metadata"] = MetadataField({"words": words}) if tags is not None: labels, detect_tags, complex_flag_dict = self.extract_tags(tags) if self._skip_complex and complex_flag_dict[self._skip_complex] > 0: return None rnd = random() # skip TN if self._skip_correct and all(x == "CORRECT" for x in detect_tags): if rnd > self._tn_prob: return None # skip TP else: if rnd > self._tp_prob: return None fields["labels"] = SequenceLabelField(labels, sequence, label_namespace="labels") fields["d_tags"] = SequenceLabelField(detect_tags, sequence, label_namespace="d_tags") return Instance(fields)
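

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original reader): shows how
# this reader might be instantiated and iterated over a labelled file. The
# file path "data/train.labeled" and the sampling probabilities below are
# hypothetical; adjust them to your own data and training setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    reader = Seq2LabelsDatasetReader(
        skip_correct=True,   # sub-sample error-free sentences
        tn_prob=0.05,        # keep ~5% of fully correct sentences
        tp_prob=1.0,         # keep every sentence that contains edits
        max_len=50,
        tag_strategy="keep_one",
    )
    for instance in reader.read("data/train.labeled"):
        print(instance.fields["tokens"], instance.fields.get("labels"))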