Spaces:
Build error
Build error
import json | |
import logging | |
import random | |
from typing import * | |
import numpy as np | |
from allennlp.data.dataset_readers.dataset_reader import DatasetReader | |
from allennlp.data.fields import MetadataField | |
from allennlp.data.instance import Instance | |
from .span_reader import SpanReader | |
from ..utils import Span, VIRTUAL_ROOT, BIOSmoothing | |
logger = logging.getLogger(__name__) | |
class SRLDatasetReader(SpanReader):
    """Read JSON-lines files of event/argument span annotations into instances.

    Each input line is a JSON object whose keys are passed as keyword
    arguments to :meth:`text_to_instance` (``tokens``, and optionally
    ``annotations`` and ``meta``).

    An optional ontology mapping can be loaded from a JSON file. It is expected
    to contain ``mapping.event`` and ``mapping.argument`` (source label ->
    list of weights) and ``target.label`` (list of target labels, which must
    include ``VIRTUAL_ROOT``).
    """

    def __init__(
            self,
            min_negative: int = 5,
            negative_ratio: float = 1.,
            event_only: bool = False,
            event_smoothing_factor: float = 0.,
            arg_smoothing_factor: float = 0.,
            # For Ontology Mapping
            ontology_mapping_path: Optional[str] = None,
            min_weight: float = 1e-2,
            max_weight: float = 1.0,
            **extra
    ):
        """Configure the reader and, if requested, load the ontology mapping.

        Args:
            min_negative: Stored on the instance as ``self.min_negative``.
            negative_ratio: Stored on the instance as ``self.negative_ratio``.
            event_only: If true, children of every event are dropped in
                :meth:`text_to_instance`.
            event_smoothing_factor: ``o_smooth`` value used for the root's
                ``BIOSmoothing``; ``0`` disables it.
            arg_smoothing_factor: ``o_smooth`` value used for each event's
                ``BIOSmoothing``; ``0`` disables it.
            ontology_mapping_path: Path to the ontology mapping JSON file, or
                ``None`` to disable ontology mapping.
            min_weight: Mapping weights strictly below this become ``0.0``.
            max_weight: Mapping weights strictly above this are clipped to it.
            **extra: Forwarded unchanged to the parent constructor.
        """
        super().__init__(**extra)
        self.min_negative = min_negative
        self.negative_ratio = negative_ratio
        self.event_only = event_only
        self.event_smooth_factor = event_smoothing_factor
        self.arg_smooth_factor = arg_smoothing_factor
        self.ontology_mapping = None
        if ontology_mapping_path is not None:
            # Use a context manager so the file handle is closed deterministically.
            with open(ontology_mapping_path) as mapping_file:
                self.ontology_mapping = json.load(mapping_file)
            mapping = self.ontology_mapping['mapping']
            for kind in ('event', 'argument'):
                kept = {}
                for label, weights in mapping[kind].items():
                    weights = np.array(weights)
                    weights[weights < min_weight] = 0.0
                    weights[weights > max_weight] = max_weight
                    # Drop source labels whose weights are (numerically) all zero.
                    if weights.sum() > 1e-5:
                        kept[label] = weights
                mapping[kind] = kept
            # The virtual root always maps onto itself with full weight.
            target_labels = self.ontology_mapping['target']['label']
            vr_label = [0.] * len(target_labels)
            vr_label[target_labels.index(VIRTUAL_ROOT)] = 1.0
            mapping['event'][VIRTUAL_ROOT] = np.array(vr_label)

    def _read(self, file_path: str) -> Iterable[Instance]:
        """Yield one instance per JSON line of ``file_path``.

        Lines for which :meth:`text_to_instance` returns ``None`` are skipped.
        After the file is exhausted, a warning is logged if any spans were
        removed and the ``n_span_removed`` counter is reset.
        """
        # Parse everything up front; the handle is closed before any instance is built.
        with open(file_path) as data_file:
            all_lines = [json.loads(raw_line) for raw_line in data_file]
        if self.debug:
            # NOTE: this seeds the *global* RNG, so the shuffle (and any later
            # use of `random`) is reproducible in debug mode.
            random.seed(1)
            random.shuffle(all_lines)
        for line in all_lines:
            ins = self.text_to_instance(**line)
            if ins is not None:
                yield ins
        if self.n_span_removed > 0:
            logger.warning(f'{self.n_span_removed} spans are removed.')
            self.n_span_removed = 0

    def apply_ontology_mapping(self, vr):
        """Map the labels of ``vr``'s events and their children to the target ontology.

        Events whose label has no entry in the event mapping are dropped, as are
        children whose label has no entry in the argument mapping. For every
        kept span, ``smooth_weight`` and ``child_smooth.weight`` are set to the
        sum of its mapping weights before ``map_ontology`` is called.

        Note that the spans reachable from ``vr`` are modified in place.

        Args:
            vr: Root span whose iteration yields events; iterating an event
                yields its children.

        Returns:
            A new virtual root built from the kept, mapped events.
        """
        event_map = self.ontology_mapping['mapping']['event']
        arg_map = self.ontology_mapping['mapping']['argument']
        new_events = []
        for event in vr:
            if event.label not in event_map:
                continue
            event.child_smooth.weight = event.smooth_weight = event_map[event.label].sum()
            event = event.map_ontology(event_map, False, False)
            new_events.append(event)
            # Collect the surviving children first, then swap them in, so the
            # event is not mutated while it is being iterated.
            new_children = []
            for child in event:
                if child.label not in arg_map:
                    continue
                child.child_smooth.weight = child.smooth_weight = arg_map[child.label].sum()
                child = child.map_ontology(arg_map, False, False)
                new_children.append(child)
            event.remove_child()
            for child in new_children:
                event.add_child(child)
        new_vr = Span.virtual_root(new_events)
        # For Virtual Root itself.
        # NOTE(review): the return value is discarded here, unlike the calls
        # above; presumably the `True` flag makes this an in-place update —
        # confirm against Span.map_ontology.
        new_vr.map_ontology(self.ontology_mapping['mapping']['event'], True, False)
        return new_vr

    def text_to_instance(self, tokens, annotations=None, meta=None) -> Optional[Instance]:
        """Build an instance from tokens and optional span annotations.

        Args:
            tokens: Passed through to ``prepare_inputs``.
            annotations: A ``Span`` or its JSON form (parsed with
                ``Span.from_json``); ``None`` when there are no annotations.
            meta: Optional metadata dict. ``fully_annotated`` defaults to
                ``True`` when missing. A non-empty dict passed in is updated
                in place.

        Returns:
            An ``Instance`` containing the fields from ``prepare_inputs`` plus
            a ``meta`` metadata field.
        """
        meta = meta or {}
        meta.setdefault('fully_annotated', True)
        vr = None
        if annotations is not None:
            vr = annotations if isinstance(annotations, Span) else Span.from_json(annotations)
            if self.ontology_mapping is not None:
                vr = self.apply_ontology_mapping(vr)
            # if len(vr) == 0: return  # Ignore sentence with empty annotation
            # Everything below dereferences `vr`, so it must stay inside this
            # guard: without annotations `vr` is None.
            if self.event_smooth_factor != 0.0:
                # Partially annotated sentences get o_smooth=-1 instead of the
                # configured factor.
                o_smooth = self.event_smooth_factor if meta['fully_annotated'] else -1
                vr.child_smooth = BIOSmoothing(o_smooth=o_smooth)
            if self.arg_smooth_factor != 0.0:
                for event in vr:
                    event.child_smooth = BIOSmoothing(o_smooth=self.arg_smooth_factor)
            if self.event_only:
                for event in vr:
                    event.remove_child()
                    event.is_parent = False

        # With an ontology mapping the 4th argument is 'list', otherwise 'string'.
        label_type = 'string' if self.ontology_mapping is None else 'list'
        fields = self.prepare_inputs(tokens, vr, True, label_type)
        fields['meta'] = MetadataField(meta)
        return Instance(fields)