Spaces:
Build error
Build error
File size: 1,984 Bytes
6680682 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from typing import *
import torch
from .span import Span
def _tensor2span_batch(
        span_boundary: torch.Tensor,
        span_labels: torch.Tensor,
        parent_indices: torch.Tensor,
        num_spans: torch.Tensor,
        label_confidence: torch.Tensor,
        idx2label: Dict[int, str],
        label_ignore: List[int],
) -> Span:
    """
    Build the span tree for a single example and return its first kept span.

    The first ``int(num_spans)`` entries of the per-span inputs are read in
    lockstep. Each entry whose label is not in ``label_ignore`` becomes a
    ``Span``; it is attached as a child of ``spans[parent_idx]`` when that
    index is smaller than the number of spans kept so far.

    :param span_boundary: one ``(start, end)`` pair per span.
    :param span_labels: one label index per span.
    :param parent_indices: one parent position per span.
    :param num_spans: scalar; number of leading entries to read.
    :param label_confidence: one confidence value per span.
    :param idx2label: maps a label index to its label string.
    :param label_ignore: label indices whose spans are dropped.
    :return: the first span that was kept (presumably the root -- TODO confirm
        against the model code).
    :raises IndexError: if no span is kept (``num_spans`` is 0 or every label
        is ignored).
    :raises KeyError: if a kept label index is missing from ``idx2label``.
    """
    count = int(num_spans)
    # Normalise once so the per-span membership test is a plain int lookup
    # instead of a tensor comparison against every list element.
    ignored = {int(idx) for idx in label_ignore}

    # Slice before zipping so only the entries that are actually used get
    # unpacked; this yields the same rows as zipping everything and slicing.
    rows = zip(
        span_boundary[:count],
        parent_indices[:count],
        span_labels[:count],
        label_confidence[:count],
    )

    spans = []
    for (start_idx, end_idx), parent_idx, label, label_conf in rows:
        label_idx = int(label)
        if label_idx in ignored:
            continue
        span = Span(int(start_idx), int(end_idx), idx2label[label_idx], True, confidence=float(label_conf))
        # NOTE(review): parent_idx indexes the list of *kept* spans. If the
        # model emits positions in the unfiltered sequence, dropping an ignored
        # span shifts every later position -- confirm with the caller.
        parent_pos = int(parent_idx)
        if parent_pos < len(spans):
            spans[parent_pos].add_child(span)
        spans.append(span)
    return spans[0]
def tensor2span(
        span_boundary: torch.Tensor,
        span_labels: torch.Tensor,
        parent_indices: torch.Tensor,
        num_spans: torch.Tensor,
        label_confidence: torch.Tensor,
        idx2label: Dict[int, str],
        label_ignore: Optional[List[int]] = None,
) -> List[Span]:
    """
    Convert batched span tensors into one ``Span`` tree per example.

    Every tensor is split along its first dimension and each slice group is
    decoded independently. Refer to the model part for the meaning of the
    tensor arguments.

    :param span_boundary: span boundaries; first dimension indexes examples.
    :param span_labels: span label indices; first dimension indexes examples.
    :param parent_indices: parent positions; first dimension indexes examples.
    :param num_spans: number of spans to read for each example.
    :param label_confidence: label confidences; first dimension indexes examples.
    :param idx2label: maps a label index to its label string.
    :param label_ignore: label indices to drop; ``None`` or empty keeps all.
    :return: a list with one ``Span`` per example, in input order. The list is
        empty when the first dimension has size 0.
    """
    label_ignore = label_ignore or []

    # Move each tensor on its own: `.to` hands back the same tensor when it is
    # already on the CPU, so this is free for CPU inputs and also covers the
    # case where the inputs do not all live on the same device.
    span_boundary, span_labels, parent_indices, num_spans, label_confidence = (
        tensor.to(device='cpu')
        for tensor in (span_boundary, span_labels, parent_indices, num_spans, label_confidence)
    )

    # Order must match the positional parameters of `_tensor2span_batch`.
    per_example = zip(
        span_boundary.unbind(0),
        span_labels.unbind(0),
        parent_indices.unbind(0),
        num_spans.unbind(0),
        label_confidence.unbind(0),
    )
    return [
        _tensor2span_batch(*args, label_ignore=label_ignore, idx2label=idx2label)
        for args in per_example
    ]
|