import torch

from espnet2.diar.decoder.abs_decoder import AbsDecoder


class LinearDecoder(AbsDecoder):
    """Speaker diarization decoder made of a single linear layer.

    The layer projects the last dimension of the encoder output
    (``encoder_output_size``) onto ``num_spk`` values.
    """

    def __init__(
        self,
        encoder_output_size: int,
        num_spk: int = 2,
    ):
        """Build the linear projection.

        Args:
            encoder_output_size (int): size of the last dimension of the
                tensor passed to ``forward``.
            num_spk (int): number of output values per frame; also exposed
                through the ``num_spk`` property.
        """
        super().__init__()
        self.linear_decoder = torch.nn.Linear(encoder_output_size, num_spk)
        self._num_spk = num_spk

    @property
    def num_spk(self):
        """Number of speakers given at construction time."""
        return self._num_spk

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Project the hidden representation through the linear layer.

        Args:
            input (torch.Tensor): hidden_space [Batch, T, F]
            ilens (torch.Tensor): input lengths [Batch]; accepted as part
                of the signature but not used by this method.

        Returns:
            torch.Tensor: result of the linear layer applied to ``input``,
                i.e. [Batch, T, num_spk] for an input of [Batch, T, F].
        """
        return self.linear_decoder(input)