from itertools import chain
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from colorama import Back, Fore, Style
from tokenizers.pre_tokenizers import Whitespace
from transformers import (
    AutoModel,
    AutoModelForTokenClassification,
    AutoTokenizer,
    Pipeline,
    PretrainedConfig,
    XLMRobertaForTokenClassification,
    XLMRobertaTokenizerFast,
)
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta import RobertaConfig
from transformers.models.xlm_roberta import XLMRobertaModel, XLMRobertaPreTrainedModel
from transformers.pipelines import PIPELINE_REGISTRY


class RefSegPipeline(Pipeline):
    """Token-classification pipeline with two heads: one predicting the segment label
    of each token (id2seg) and one predicting reference boundaries (id2ref)."""

    labels = [
        'publisher', 'source', 'url', 'other', 'author', 'editor', 'lpage',
        'volume', 'year', 'issue', 'title', 'fpage', 'edition'
    ]
    iob_labels = list(chain.from_iterable([['B-' + x, 'I-' + x] for x in labels])) + ['O']
    id2seg = {k: v for k, v in enumerate(iob_labels)}
    id2ref = {k: v for k, v in enumerate(['B-ref', 'I-ref'])}
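    # Illustration (derived from the definitions above): iob_labels expands to
    # ['B-publisher', 'I-publisher', 'B-source', 'I-source', ..., 'O'], so id2seg
    # maps 0 -> 'B-publisher', 1 -> 'I-publisher', ..., 26 -> 'O', while id2ref
    # maps 0 -> 'B-ref' and 1 -> 'I-ref'.
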
    def _sanitize_parameters(self, **kwargs):
        # Allow the label maps to be overridden when the pipeline is created or called.
        if "id2seg" in kwargs:
            self.id2seg = kwargs["id2seg"]
        if "id2ref" in kwargs:
            self.id2ref = kwargs["id2ref"]
        return {}, {}, {}

    def preprocess(self, sentence, offset_mapping=None):
        # Tokenize into 512-token chunks (long inputs overflow into extra chunks) and
        # keep offsets plus special-token masks so predictions can be mapped back to
        # the original string.
        model_inputs = self.tokenizer(
            sentence,
            return_offsets_mapping=True,
            padding='max_length',
            truncation=True,
            max_length=512,
            return_tensors="pt",
            return_special_tokens_mask=True,
            return_overflowing_tokens=True,
        )

        if offset_mapping:
            model_inputs["offset_mapping"] = offset_mapping

        model_inputs["sentence"] = sentence

        return model_inputs

    def _forward(self, model_inputs):
        # Pop everything the model does not accept, run the forward pass, and pass the
        # bookkeeping tensors through to postprocess.
        special_tokens_mask = model_inputs.pop("special_tokens_mask")
        offset_mapping = model_inputs.pop("offset_mapping", None)
        sentence = model_inputs.pop("sentence")
        overflow_mapping = model_inputs.pop("overflow_to_sample_mapping")
        if self.framework == "tf":
            logits = self.model(model_inputs.data)[0]
        else:
            logits = self.model(**model_inputs)[0]

        return {
            "logits": logits,
            "special_tokens_mask": special_tokens_mask,
            "offset_mapping": offset_mapping,
            "overflow_mapping": overflow_mapping,
            "sentence": sentence,
            **model_inputs,
        }

    def postprocess(self, model_outputs):
        ignore_labels = ["O"]
        # The model returns two sets of logits: segment labels and reference boundaries.
        logits_seg = model_outputs["logits"][0].numpy()
        logits_ref = model_outputs["logits"][1].numpy()
        sentence = model_outputs["sentence"]
        input_ids = model_outputs["input_ids"]
        special_tokens_mask = model_outputs["special_tokens_mask"]
        overflow_mapping = model_outputs["overflow_mapping"]
        offset_mapping = model_outputs["offset_mapping"]

        # Numerically stable softmax over both heads.
        maxes_seg = np.max(logits_seg, axis=-1, keepdims=True)
        shifted_exp_seg = np.exp(logits_seg - maxes_seg)
        scores_seg = shifted_exp_seg / shifted_exp_seg.sum(axis=-1, keepdims=True)

        maxes_ref = np.max(logits_ref, axis=-1, keepdims=True)
        shifted_exp_ref = np.exp(logits_ref - maxes_ref)
        scores_ref = shifted_exp_ref / shifted_exp_ref.sum(axis=-1, keepdims=True)

        pre_entities = self.gather_pre_entities(
            sentence, input_ids, scores_seg, scores_ref, offset_mapping, special_tokens_mask
        )
        grouped_entities = self.aggregate(pre_entities)

        # Drop tokens labelled "O" from every reference group.
        cleaned_groups = []
        for group in grouped_entities:
            entities = [
                entity
                for entity in group
                if entity.get("entity_group", None) not in ignore_labels
            ]
            cleaned_groups.append(entities)
        return {
            "number_of_references": len(cleaned_groups),
            "references": cleaned_groups,
        }

    def gather_pre_entities(
        self,
        sentence: str,
        input_ids: np.ndarray,
        scores_seg: np.ndarray,
        scores_ref: np.ndarray,
        offset_mappings: Optional[List[Tuple[int, int]]],
        special_tokens_masks: np.ndarray,
    ) -> List[dict]:
        """Fuse the various numpy arrays into dicts with all the information needed for aggregation."""
        pre_entities = []
        for input_id, offset_mapping, special_tokens_mask, s_seg, s_ref in zip(
                input_ids, offset_mappings, special_tokens_masks, scores_seg, scores_ref):
            for idx in range(len(input_id)):
                # Skip padding and other special tokens.
                if special_tokens_mask[idx]:
                    continue

                word = self.tokenizer.convert_ids_to_tokens(int(input_id[idx]))
                if offset_mapping is not None:
                    start_ind, end_ind = offset_mapping[idx]
                    if not isinstance(start_ind, int):
                        if self.framework == "pt":
                            start_ind = start_ind.item()
                            end_ind = end_ind.item()
                    word_ref = sentence[start_ind:end_ind]
                    if getattr(self.tokenizer._tokenizer.model, "continuing_subword_prefix", None):
                        # Word-aware tokenizer: token and text-span lengths differ only for subwords.
                        is_subword = len(word) != len(word_ref)
                    else:
                        # SentencePiece-style heuristic: word-initial pieces carry a leading "▁"
                        # that is not part of the text span, so equal lengths are treated as a
                        # continuation piece.
                        is_subword = len(word) == len(word_ref)

                    if int(input_id[idx]) == self.tokenizer.unk_token_id:
                        word = word_ref
                        is_subword = False
                else:
                    start_ind = None
                    end_ind = None
                    is_subword = False

                pre_entity = {
                    "word": word,
                    "scores_seg": s_seg[idx],
                    "scores_ref": s_ref[idx],
                    "start": start_ind,
                    "end": end_ind,
                    "index": idx,
                    "is_subword": is_subword,
                }
                pre_entities.append(pre_entity)
        return pre_entities

    def aggregate(self, pre_entities: List[dict]) -> List[dict]:
        # Merge subword tokens into words, then group words into fields and references.
        entities = self.aggregate_words(pre_entities)
        return self.group_entities(entities)

    def aggregate_word(self, entities: List[dict]) -> dict:
        word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])

        # Segment label: take the prediction of the word's first subword token.
        scores_seg = entities[0]["scores_seg"]
        idx_seg = scores_seg.argmax()
        score_seg = scores_seg[idx_seg]
        entity_seg = self.id2seg[idx_seg]

        # Reference boundary: the word is "I-ref" only if every subword token predicts
        # the inside class; otherwise it starts a new reference ("B-ref").
        scores_ref = np.stack([entity["scores_ref"] for entity in entities])
        indices_ref = scores_ref.argmax(axis=1)
        idx_ref = 1 if all(indices_ref) else 0
        entity_ref = self.id2ref[idx_ref]

        new_entity = {
            "entity_seg": entity_seg,
            "score_seg": score_seg,
            "entity_ref": entity_ref,
            "word": word,
            "start": entities[0]["start"],
            "end": entities[-1]["end"],
        }
        return new_entity

    def aggregate_words(self, entities: List[dict]) -> List[dict]:
        """
        Override tokens from a given word that disagree to force agreement on word boundaries.

        Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten as
        microsoft| company| B-ENT I-ENT.
        """
        word_entities = []
        word_group = None
        for entity in entities:
            if word_group is None:
                word_group = [entity]
            elif entity["is_subword"]:
                word_group.append(entity)
            else:
                word_entities.append(self.aggregate_word(word_group))
                word_group = [entity]
        # Flush the last word group (guard against empty input).
        if word_group is not None:
            word_entities.append(self.aggregate_word(word_group))
        return word_entities

    def group_entities(self, entities: List[dict]) -> List[dict]:
        """
        Find and group together adjacent words with the same predicted entity.

        Args:
            entities (`dict`): The entities predicted by the pipeline.
        """
        # First pass: split the word stream into references using the B-ref/I-ref head.
        entity_chunk = []
        entity_chunk_disagg = []

        for entity in entities:
            if not entity_chunk_disagg:
                entity_chunk_disagg.append(entity)
                continue

            bi_ref, tag_ref = self.get_tag(entity["entity_ref"])
            last_bi_ref, last_tag_ref = self.get_tag(entity_chunk_disagg[-1]["entity_ref"])

            if tag_ref == last_tag_ref and bi_ref != "B":
                # Still inside the same reference: extend the current chunk.
                entity_chunk_disagg.append(entity)
            else:
                # A "B-ref" word starts a new reference chunk.
                entity_chunk.append(entity_chunk_disagg)
                entity_chunk_disagg = [entity]

        if entity_chunk_disagg:
            entity_chunk.append(entity_chunk_disagg)

        # Second pass: inside each reference, group adjacent words with the same segment tag.
        entity_chunks_all = []

        for chunk in entity_chunk:
            entity_groups = []
            entity_group_disagg = []

            for entity in chunk:
                if not entity_group_disagg:
                    entity_group_disagg.append(entity)
                    continue

                bi_seg, tag_seg = self.get_tag(entity["entity_seg"])
                last_bi_seg, last_tag_seg = self.get_tag(entity_group_disagg[-1]["entity_seg"])

                if tag_seg == last_tag_seg and bi_seg != "B":
                    entity_group_disagg.append(entity)
                else:
                    entity_groups.append(self.group_sub_entities(entity_group_disagg))
                    entity_group_disagg = [entity]

            if entity_group_disagg:
                entity_groups.append(self.group_sub_entities(entity_group_disagg))

            entity_chunks_all.append(entity_groups)

        return entity_chunks_all

    def group_sub_entities(self, entities: List[dict]) -> dict:
        """
        Group together the adjacent tokens with the same entity predicted.

        Args:
            entities (`dict`): The entities predicted by the pipeline.
        """
        entity = entities[0]["entity_seg"].split("-")[-1]
        scores = np.nanmean([entity["score_seg"] for entity in entities])
        tokens = [entity["word"] for entity in entities]

        entity_group = {
            "entity_group": entity,
            "score": np.mean(scores),
            "word": " ".join(tokens),
            "start": entities[0]["start"],
            "end": entities[-1]["end"],
        }
        return entity_group

    def get_tag(self, entity_name: str) -> Tuple[str, str]:
        # Split an IOB2 label such as "B-title" into its prefix and tag.
        if entity_name.startswith("B-"):
            bi = "B"
            tag = entity_name[2:]
        elif entity_name.startswith("I-"):
            bi = "I"
            tag = entity_name[2:]
        else:
            # Labels without an explicit prefix (e.g. "O") are treated as inside tokens.
            bi = "I"
            tag = entity_name
        return bi, tag
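

# Usage sketch (not from the original file; names below are assumptions): one way to
# register this pipeline and run it on a raw reference string. The task name
# "ref-segmentation" and the checkpoint id "path/to/refseg-checkpoint" are placeholders,
# and the checkpoint is assumed to provide the dual-head token-classification model this
# pipeline expects (segment logits plus reference-boundary logits), e.g. via auto_map /
# trust_remote_code.
if __name__ == "__main__":
    from transformers import pipeline

    PIPELINE_REGISTRY.register_pipeline(
        "ref-segmentation",
        pipeline_class=RefSegPipeline,
        pt_model=AutoModelForTokenClassification,
    )

    ref_segmenter = pipeline(
        "ref-segmentation",
        model="path/to/refseg-checkpoint",  # placeholder model id
        trust_remote_code=True,
    )

    output = ref_segmenter(
        "Smith, J. (2020). Deep learning for reference parsing. Journal of Examples, 12(3), 45-67."
    )
    print(output["number_of_references"])
    for reference in output["references"]:
        print([(entity["entity_group"], entity["word"]) for entity in reference])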