File size: 3,487 Bytes
6680682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
from allennlp.modules.conditional_random_field import ConditionalRandomField
from allennlp.nn.util import logsumexp
from overrides import overrides


class SmoothCRF(ConditionalRandomField):
    @overrides
    def forward(self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor = None):
        """

        :param inputs: Shape [batch, token, tag]
        :param tags: Shape [batch, token] or [batch, token, tag]
        :param mask: Shape [batch, token]
        :return:
        """
        if mask is None:
            mask = tags.new_ones(tags.shape, dtype=torch.bool)
        mask = mask.to(dtype=torch.bool)
        if tags.dim() == 2:
            return super(SmoothCRF, self).forward(inputs, tags, mask)

        # smooth mode
        log_denominator = self._input_likelihood(inputs, mask)
        log_numerator = self._smooth_joint_likelihood(inputs, tags, mask)

        return torch.sum(log_numerator - log_denominator)

    def _smooth_joint_likelihood(
        self, logits: torch.Tensor, soft_tags: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        batch_size, sequence_length, num_tags = logits.size()

        epsilon = 1e-30
        soft_tags = soft_tags.clone()
        soft_tags[soft_tags < epsilon] = epsilon

        # Transpose batch size and sequence dimensions
        mask = mask.transpose(0, 1).contiguous()
        logits = logits.transpose(0, 1).contiguous()
        soft_tags = soft_tags.transpose(0, 1).contiguous()

        # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the
        # transitions to the initial states and the logits for the first timestep.
        if self.include_start_end_transitions:
            alpha = self.start_transitions.view(1, num_tags) + logits[0] + soft_tags[0].log()
        else:
            alpha = logits[0] * soft_tags[0]

        # For each i we compute logits for the transitions from timestep i-1 to timestep i.
        # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are
        # (instance, current_tag, next_tag)
        for i in range(1, sequence_length):
            # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis.
            emit_scores = logits[i].view(batch_size, 1, num_tags)
            # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis.
            transition_scores = self.transitions.view(1, num_tags, num_tags)
            # Alpha is for the current_tag, so we broadcast along the next_tag axis.
            broadcast_alpha = alpha.view(batch_size, num_tags, 1)

            # Add all the scores together and logexp over the current_tag axis.
            inner = broadcast_alpha + emit_scores + transition_scores + soft_tags[i].log().unsqueeze(1)

            # In valid positions (mask == True) we want to take the logsumexp over the current_tag dimension
            # of `inner`. Otherwise (mask == False) we want to retain the previous alpha.
            alpha = logsumexp(inner, 1) * mask[i].view(batch_size, 1) + alpha * (
                ~mask[i]
            ).view(batch_size, 1)

        # Every sequence needs to end with a transition to the stop_tag.
        if self.include_start_end_transitions:
            stops = alpha + self.end_transitions.view(1, num_tags)
        else:
            stops = alpha

        # Finally we log_sum_exp along the num_tags dim, result is (batch_size,)
        return logsumexp(stops)