from typing import *

import torch
from allennlp.modules.span_extractors import SpanExtractor


@SpanExtractor.register('combo')
class ComboSpanExtractor(SpanExtractor):
    """Span extractor that concatenates the outputs of several sub-extractors.

    Every sub-extractor is called with the same inputs and the resulting
    span representations are concatenated along dimension 2, so the output
    width is the sum of the sub-extractors' output widths.
    """

    def __init__(self, input_dim: int, sub_extractors: List[SpanExtractor]):
        """Store the sub-extractors and register them as child modules.

        Args:
            input_dim: Value reported by ``get_input_dim``. It is stored as
                given and is not checked against the sub-extractors.
            sub_extractors: Extractors whose outputs are concatenated, in
                this order.
        """
        super().__init__()
        self.sub_extractors = sub_extractors
        # A plain list attribute is not tracked by torch.nn.Module, so each
        # sub-extractor is registered explicitly. The names are kept stable
        # because they become part of the state_dict keys.
        for i, sub in enumerate(sub_extractors):
            self.add_module(f'SpanExtractor-{i+1}', sub)
        self.input_dim = input_dim

    def get_input_dim(self) -> int:
        """Return the input dimension given at construction time."""
        return self.input_dim

    def get_output_dim(self) -> int:
        """Return the summed output dimension of all sub-extractors."""
        return sum(sub.get_output_dim() for sub in self.sub_extractors)

    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: Optional[torch.BoolTensor] = None,
        span_indices_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:
        """Run every sub-extractor and concatenate their span representations.

        Args:
            sequence_tensor: Sequence representations passed unchanged to
                each sub-extractor.
            span_indices: Span indices passed unchanged to each sub-extractor.
            sequence_mask: Optional mask over ``sequence_tensor``; forwarded
                to each sub-extractor.
            span_indices_mask: Optional mask over ``span_indices``; forwarded
                to each sub-extractor.

        Returns:
            The sub-extractor outputs concatenated along dimension 2
            (presumably ``(batch, num_spans, get_output_dim())`` -- confirm
            against the sub-extractors in use).
        """
        outputs = [
            sub(
                sequence_tensor=sequence_tensor,
                span_indices=span_indices,
                # Previously dropped: without it, mask-aware sub-extractors
                # would always run as if no sequence mask had been given.
                sequence_mask=sequence_mask,
                span_indices_mask=span_indices_mask,
            )
            for sub in self.sub_extractors
        ]
        return torch.cat(outputs, dim=2)