from typing import Dict, List, Optional

import torch

from .span import Span


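def _tensor2span_batch(
        span_boundary: torch.Tensor,
        span_labels: torch.Tensor,
        parent_indices: torch.Tensor,
        num_spans: torch.Tensor,
        label_confidence: torch.Tensor,
        idx2label: Dict[int, str],
        label_ignore: List[int],
) -> Span:
    """
    Build the span tree for a single batch item and return its root span.
    Raises IndexError if no span is kept (e.g. every label is ignored).
    """
    spans = list()
    # Only the first `num_spans` rows are read; the remaining rows are not used.
    for (start_idx, end_idx), parent_idx, label, label_conf in \
            list(zip(span_boundary, parent_indices, span_labels, label_confidence))[:int(num_spans)]:
        # Compare as a plain int rather than as a 0-dim tensor.
        if int(label) not in label_ignore:
            span = Span(int(start_idx), int(end_idx), idx2label[int(label)], True, confidence=float(label_conf))
            # `parent_idx` indexes into `spans`, which holds only the spans kept so far.
            # The first kept span has no earlier entry to attach to and becomes the root.
            if int(parent_idx) < len(spans):
                spans[int(parent_idx)].add_child(span)
            spans.append(span)
    return spans[0]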


def tensor2span(
        span_boundary: torch.Tensor,
        span_labels: torch.Tensor,
        parent_indices: torch.Tensor,
        num_spans: torch.Tensor,
        label_confidence: torch.Tensor,
        idx2label: Dict[int, str],
        label_ignore: Optional[List[int]] = None,
) -> List[Span]:
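    """
    Convert batched output tensors into span trees, one root Span per batch item.
    Refer to the model part for the meaning of these variables.
    If label_ignore is provided, spans whose label index appears in it are skipped.
    :return: A list with the root Span of each batch item.
    """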
    label_ignore = label_ignore or []
    if span_boundary.device.type != 'cpu':
        span_boundary = span_boundary.to(device='cpu')
        parent_indices = parent_indices.to(device='cpu')
        span_labels = span_labels.to(device='cpu')
        num_spans = num_spans.to(device='cpu')
        label_confidence = label_confidence.to(device='cpu')

    ret = list()
    for args in zip(
            span_boundary.unbind(0), span_labels.unbind(0), parent_indices.unbind(0), num_spans.unbind(0),
            label_confidence.unbind(0),
    ):
        ret.append(_tensor2span_batch(*args, label_ignore=label_ignore, idx2label=idx2label))

    return ret
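

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: how a flat list of
# rows with parent indices, like the one read by `_tensor2span_batch`, can be
# folded into a tree. `_SketchSpan` and `_sketch_build_tree` are hypothetical
# stand-ins that use plain Python values instead of tensors and do not depend
# on `Span`. It is assumed that the first row is the root and that every
# parent appears before its children.
# ---------------------------------------------------------------------------
from typing import List, Tuple


class _SketchSpan:
    """Minimal stand-in for `Span`: boundaries, a label, and child spans."""

    def __init__(self, start: int, end: int, label: str):
        self.start = start
        self.end = end
        self.label = label
        self.children = []


def _sketch_build_tree(rows: List[Tuple[int, int, int, str]]) -> _SketchSpan:
    """Each row is (start, end, parent_index, label); returns the root span."""
    spans = []
    for start, end, parent_idx, label in rows:
        span = _SketchSpan(start, end, label)
        # Same guard as in `_tensor2span_batch`: the first row points at
        # index 0 while `spans` is still empty, so it is attached to nothing.
        if parent_idx < len(spans):
            spans[parent_idx].children.append(span)
        spans.append(span)
    return spans[0]


# Usage: three rows forming the chain root -> event -> argument.
_sketch_root = _sketch_build_tree([
    (0, 9, 0, "root"),
    (1, 3, 0, "event"),
    (2, 2, 1, "argument"),
])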