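"""Custom `transformers` pipeline for reference segmentation.

The pipeline expects an XLM-RoBERTa-style token-classification model that returns two sets of
logits: IOB field labels per token (author, title, year, ...) and reference boundaries
(B-ref / I-ref). Post-processing fuses sub-tokens into words, words into fields, and fields
into individual references.
"""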
from transformers import AutoTokenizer, XLMRobertaForTokenClassification, Pipeline, AutoModelForTokenClassification, AutoModel, XLMRobertaTokenizerFast
from tokenizers.pre_tokenizers import Whitespace
from transformers.pipelines import PIPELINE_REGISTRY
from itertools import chain
from colorama import Fore, Back
from colorama import Style
import numpy as np
from transformers.models.xlm_roberta import XLMRobertaPreTrainedModel, XLMRobertaModel
from transformers.models.roberta import RobertaConfig
from transformers.modeling_outputs import TokenClassifierOutput
from transformers import PretrainedConfig
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from typing import List, Optional, Tuple, Union
class RefSegPipeline(Pipeline):
labels = [
'publisher', 'source', 'url', 'other', 'author', 'editor', 'lpage',
'volume', 'year', 'issue', 'title', 'fpage', 'edition'
]
iob_labels = list(chain.from_iterable([['B-' + x, 'I-' + x] for x in labels])) + ['O']
    id2seg = dict(enumerate(iob_labels))
    id2ref = dict(enumerate(['B-ref', 'I-ref']))
def _sanitize_parameters(self, **kwargs):
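        """Accept optional `id2seg` / `id2ref` pipeline kwargs to override the default label maps."""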
if "id2seg" in kwargs:
self.id2seg = kwargs["id2seg"]
if "id2ref" in kwargs:
self.id2ref = kwargs["id2ref"]
return {}, {}, {}
def preprocess(self, sentence, offset_mapping=None):
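        """Tokenize the input into overflowing 512-token windows, keeping offset mappings and special-token masks for postprocessing."""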
model_inputs = self.tokenizer(
sentence,
return_offsets_mapping=True,
padding='max_length',
truncation=True,
max_length=512,
return_tensors="pt",
return_special_tokens_mask=True,
return_overflowing_tokens=True
)
if offset_mapping:
model_inputs["offset_mapping"] = offset_mapping
model_inputs["sentence"] = sentence
return model_inputs
def _forward(self, model_inputs):
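        """Run the model and pass the bookkeeping tensors (offsets, masks, overflow mapping, raw sentence) through to postprocessing."""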
special_tokens_mask = model_inputs.pop("special_tokens_mask")
offset_mapping = model_inputs.pop("offset_mapping", None)
sentence = model_inputs.pop("sentence")
overflow_mapping = model_inputs.pop("overflow_to_sample_mapping")
if self.framework == "tf":
logits = self.model(model_inputs.data)[0]
else:
logits = self.model(**model_inputs)[0]
return {
"logits": logits,
"special_tokens_mask": special_tokens_mask,
"offset_mapping": offset_mapping,
"overflow_mapping": overflow_mapping,
"sentence": sentence,
**model_inputs,
}
def postprocess(self, model_outputs):
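        """Turn the two logit heads into per-token scores, aggregate them into words, fields and references, and drop 'O' entities."""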
        ignore_labels = ["O"]
logits_seg = model_outputs["logits"][0].numpy()
logits_ref = model_outputs["logits"][1].numpy()
sentence = model_outputs["sentence"]
input_ids = model_outputs["input_ids"]
special_tokens_mask = model_outputs["special_tokens_mask"]
overflow_mapping = model_outputs["overflow_mapping"]
        offset_mapping = model_outputs["offset_mapping"]
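        # Numerically stable softmax over both heads: segmentation labels and reference boundaries.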
maxes_seg = np.max(logits_seg, axis=-1, keepdims=True)
shifted_exp_seg = np.exp(logits_seg - maxes_seg)
scores_seg = shifted_exp_seg / shifted_exp_seg.sum(axis=-1, keepdims=True)
maxes_ref = np.max(logits_ref, axis=-1, keepdims=True)
shifted_exp_ref = np.exp(logits_ref - maxes_ref)
scores_ref = shifted_exp_ref / shifted_exp_ref.sum(axis=-1, keepdims=True)
pre_entities = self.gather_pre_entities(
sentence, input_ids, scores_seg, scores_ref, offset_mapping, special_tokens_mask
)
grouped_entities = self.aggregate(pre_entities)
cleaned_groups = []
for group in grouped_entities:
entities = [
entity
for entity in group
if entity.get("entity_group", None) not in ignore_labels
]
cleaned_groups.append(entities)
return {
"number_of_references": len(cleaned_groups),
"references": cleaned_groups,
}
def gather_pre_entities(
self,
sentence: str,
input_ids: np.ndarray,
scores_seg: np.ndarray,
scores_ref: np.ndarray,
offset_mappings: Optional[List[Tuple[int, int]]],
special_tokens_masks: np.ndarray,
) -> List[dict]:
"""Fuse various numpy arrays into dicts with all the information needed for aggregation"""
pre_entities = []
        for input_id, offset_mapping, special_tokens_mask, s_seg, s_ref in zip(
            input_ids, offset_mappings, special_tokens_masks, scores_seg, scores_ref
        ):
for idx, iid in enumerate(input_id):
if special_tokens_mask[idx]:
continue
word = self.tokenizer.convert_ids_to_tokens(int(input_id[idx]))
if offset_mapping is not None:
start_ind, end_ind = offset_mapping[idx]
if not isinstance(start_ind, int):
if self.framework == "pt":
start_ind = start_ind.item()
end_ind = end_ind.item()
word_ref = sentence[start_ind:end_ind]
                    if getattr(self.tokenizer._tokenizer.model, "continuing_subword_prefix", None):
                        # The tokenizer marks continuation subwords explicitly, so a subword's
                        # surface form differs in length from the original text span.
                        is_subword = len(word) != len(word_ref)
                    else:
                        # SentencePiece-style tokenizers (e.g. XLM-RoBERTa) prefix word-initial
                        # tokens, so equal lengths indicate a continuation subword.
                        is_subword = len(word) == len(word_ref)
if int(input_id[idx]) == self.tokenizer.unk_token_id:
word = word_ref
is_subword = False
else:
start_ind = None
end_ind = None
is_subword = False
pre_entity = {
"word": word,
"scores_seg": s_seg[idx],
"scores_ref": s_ref[idx],
"start": start_ind,
"end": end_ind,
"index": idx,
"is_subword": is_subword,
}
pre_entities.append(pre_entity)
return pre_entities
def aggregate(self, pre_entities: List[dict]) -> List[dict]:
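        """Aggregate sub-token predictions into word-level entities, then group them into fields per reference."""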
entities = self.aggregate_words(pre_entities)
return self.group_entities(entities)
def aggregate_word(self, entities: List[dict]) -> dict:
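        """
        Merge the sub-tokens of one word: the first sub-token's scores decide the segmentation
        label; the word counts as inside a reference (I-ref) only if every sub-token predicts it.
        """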
word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])
scores_seg = entities[0]["scores_seg"]
idx_seg = scores_seg.argmax()
score_seg = scores_seg[idx_seg]
entity_seg = self.id2seg[idx_seg]
        scores_ref = np.stack([entity["scores_ref"] for entity in entities])
        indices_ref = scores_ref.argmax(axis=1)
        # The word continues a reference (I-ref) only if every sub-token predicts index 1;
        # otherwise it is treated as the beginning of a new reference (B-ref).
        idx_ref = 1 if all(indices_ref) else 0
        entity_ref = self.id2ref[idx_ref]
        new_entity = {
            "entity_seg": entity_seg,
            "score_seg": score_seg,
            "entity_ref": entity_ref,
            "word": word,
            "start": entities[0]["start"],
            "end": entities[-1]["end"],
        }
return new_entity
    def aggregate_words(self, entities: List[dict]) -> List[dict]:
        """
        Override sub-tokens of the same word that disagree, forcing agreement on word boundaries.
        Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT is rewritten as
        microsoft| company| B-ENT I-ENT, using the first sub-token's scores for each word.
        """
word_entities = []
word_group = None
for entity in entities:
if word_group is None:
word_group = [entity]
elif entity["is_subword"]:
word_group.append(entity)
else:
word_entities.append(self.aggregate_word(word_group))
word_group = [entity]
word_entities.append(self.aggregate_word(word_group))
return word_entities
    def group_entities(self, entities: List[dict]) -> List[dict]:
        """
        Split the word entities into individual references using the B-ref/I-ref predictions,
        then group adjacent words with the same predicted field label within each reference.
        Args:
            entities (`dict`): The entities predicted by the pipeline.
        """
entity_chunk = []
entity_chunk_disagg = []
for entity in entities:
if not entity_chunk_disagg:
entity_chunk_disagg.append(entity)
continue
bi_ref, tag_ref = self.get_tag(entity["entity_ref"])
last_bi_ref, last_tag_ref = self.get_tag(entity_chunk_disagg[-1]["entity_ref"])
if tag_ref == last_tag_ref and bi_ref != "B":
entity_chunk_disagg.append(entity)
else:
entity_chunk.append(entity_chunk_disagg)
entity_chunk_disagg = [entity]
if entity_chunk_disagg:
entity_chunk.append(entity_chunk_disagg)
entity_chunks_all = []
for chunk in entity_chunk:
entity_groups = []
entity_group_disagg = []
for entity in chunk:
if not entity_group_disagg:
entity_group_disagg.append(entity)
continue
bi_seg, tag_seg = self.get_tag(entity["entity_seg"])
last_bi_seg, last_tag_seg = self.get_tag(entity_group_disagg[-1]["entity_seg"])
if tag_seg == last_tag_seg and bi_seg != "B":
entity_group_disagg.append(entity)
else:
entity_groups.append(self.group_sub_entities(entity_group_disagg))
entity_group_disagg = [entity]
if entity_group_disagg:
entity_groups.append(self.group_sub_entities(entity_group_disagg))
entity_chunks_all.append(entity_groups)
return entity_chunks_all
def group_sub_entities(self, entities: List[dict]) -> dict:
"""
Group together the adjacent tokens with the same entity predicted.
Args:
entities (`dict`): The entities predicted by the pipeline.
"""
        entity = entities[0]["entity_seg"].split("-")[-1]
        score = np.nanmean([entity["score_seg"] for entity in entities])
        tokens = [entity["word"] for entity in entities]
        entity_group = {
            "entity_group": entity,
            "score": score,
            "word": " ".join(tokens),
            "start": entities[0]["start"],
            "end": entities[-1]["end"],
        }
return entity_group
def get_tag(self, entity_name: str) -> Tuple[str, str]:
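        """Split an IOB label into its (B/I, tag) parts; labels without a prefix are treated as I."""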
if entity_name.startswith("B-"):
bi = "B"
tag = entity_name[2:]
elif entity_name.startswith("I-"):
bi = "I"
tag = entity_name[2:]
else:
bi = "I"
tag = entity_name
return bi, tag
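# Minimal usage sketch (illustrative only): the task name and repository id below are
# placeholders, and the model is assumed to return the two logit heads that `postprocess`
# expects.
#
#   from transformers import pipeline
#
#   PIPELINE_REGISTRY.register_pipeline(
#       "reference-segmentation",
#       pipeline_class=RefSegPipeline,
#       pt_model=AutoModelForTokenClassification,
#   )
#   ref_seg = pipeline("reference-segmentation", model="<model-repo-id>", trust_remote_code=True)
#   result = ref_seg("Smith, J. (2020). Some title. Some Publisher.")
#   print(result["number_of_references"], result["references"])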