|
import torch |
|
from allennlp.modules.conditional_random_field import ConditionalRandomField |
|
from allennlp.nn.util import logsumexp |
|
from overrides import overrides |
|
|
|
|
|
class SmoothCRF(ConditionalRandomField):
    """Conditional random field that also accepts *soft* (probabilistic) tags.

    With hard integer tags of shape ``[batch, token]`` this behaves exactly like
    ``ConditionalRandomField``. With soft tags of shape ``[batch, token, tag]``
    (a distribution over tags per token) the numerator of the log-likelihood is
    computed with each path weighted by the product of its per-token tag
    probabilities.
    """

    @overrides
    def forward(self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Compute the (smoothed) log-likelihood of ``tags`` given ``inputs``.

        :param inputs: emission scores, shape [batch, token, tag]
        :param tags: hard tag indices ([batch, token]) or soft tag
            distributions ([batch, token, tag])
        :param mask: shape [batch, token]; defaults to all ones
        :return: scalar tensor, the sum of per-sequence log-likelihoods
        """
        if mask is None:
            # The mask is per-token ([batch, token]) even when tags are soft
            # ([batch, token, tag]), so only the first two dims of tags.shape
            # are used. (The original used tags.shape, which produced a 3-D
            # mask of the wrong shape for soft tags.)
            mask = tags.new_ones(tags.shape[:2], dtype=torch.bool)
        mask = mask.to(dtype=torch.bool)
        if tags.dim() == 2:
            # Hard tags: defer to the standard CRF log-likelihood.
            return super(SmoothCRF, self).forward(inputs, tags, mask)

        # Soft tags: log p = log(weighted numerator) - log(partition)
        log_denominator = self._input_likelihood(inputs, mask)
        log_numerator = self._smooth_joint_likelihood(inputs, tags, mask)

        return torch.sum(log_numerator - log_denominator)

    def _smooth_joint_likelihood(
        self, logits: torch.Tensor, soft_tags: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Forward-algorithm numerator with each path weighted by soft tags.

        :param logits: emission scores, shape [batch, token, tag]
        :param soft_tags: per-token tag distributions, shape [batch, token, tag]
        :param mask: shape [batch, token], bool
        :return: shape [batch] tensor of log numerator scores
        """
        batch_size, sequence_length, num_tags = logits.size()

        # Clamp zero probabilities so .log() below stays finite.
        epsilon = 1e-30
        soft_tags = soft_tags.clone()
        soft_tags[soft_tags < epsilon] = epsilon

        # Switch to time-major layout for the forward recursion.
        mask = mask.transpose(0, 1).contiguous()
        logits = logits.transpose(0, 1).contiguous()
        soft_tags = soft_tags.transpose(0, 1).contiguous()

        if self.include_start_end_transitions:
            alpha = self.start_transitions.view(1, num_tags) + logits[0] + soft_tags[0].log()
        else:
            # BUG FIX: the original computed ``logits[0] * soft_tags[0]``,
            # multiplying log-space scores by raw probabilities. The soft-tag
            # weighting must be additive in log space (``+ log q``) to match
            # the branch above and every subsequent recursion step.
            alpha = logits[0] + soft_tags[0].log()

        for i in range(1, sequence_length):
            # Emission scores broadcast over the previous tag dimension.
            emit_scores = logits[i].view(batch_size, 1, num_tags)

            # transitions[prev, next], broadcast over the batch.
            transition_scores = self.transitions.view(1, num_tags, num_tags)

            # alpha broadcast over the next-tag dimension.
            broadcast_alpha = alpha.view(batch_size, num_tags, 1)

            # Weight each candidate next tag by its soft probability
            # (added in log space).
            inner = broadcast_alpha + emit_scores + transition_scores + soft_tags[i].log().unsqueeze(1)

            # Advance alpha only where the token is unmasked; masked positions
            # carry the previous alpha forward unchanged.
            alpha = logsumexp(inner, 1) * mask[i].view(batch_size, 1) + alpha * (
                ~mask[i]
            ).view(batch_size, 1)

        if self.include_start_end_transitions:
            stops = alpha + self.end_transitions.view(1, num_tags)
        else:
            stops = alpha

        # logsumexp over tags yields the per-sequence log numerator.
        return logsumexp(stops)
|
|