conex / espnet2 /diar /decoder /linear_decoder.py
tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
No virus
755 Bytes
import torch
from espnet2.diar.decoder.abs_decoder import AbsDecoder
class LinearDecoder(AbsDecoder):
    """Speaker-diarization decoder made of a single linear layer.

    The last dimension of the incoming hidden representation is projected
    from ``encoder_output_size`` features down to ``num_spk`` features.
    """

    def __init__(
        self,
        encoder_output_size: int,
        num_spk: int = 2,
    ):
        """Build the projection layer.

        Args:
            encoder_output_size (int): feature size of the tensor that will
                be passed to ``forward`` (input size of the linear layer).
            num_spk (int): number of speakers (output size of the linear
                layer).
        """
        super().__init__()
        self.linear_decoder = torch.nn.Linear(
            in_features=encoder_output_size,
            out_features=num_spk,
        )
        self._num_spk = num_spk

    def forward(self, input: torch.Tensor, ilens: torch.Tensor):
        """Apply the linear projection to the hidden representation.

        Args:
            input (torch.Tensor): hidden_space [Batch, T, F]
            ilens (torch.Tensor): input lengths [Batch]; not read by this
                decoder (presumably accepted to match the base-class
                interface -- confirm against AbsDecoder).

        Returns:
            torch.Tensor: output of the linear layer, i.e. ``input`` with
            its last dimension mapped to ``num_spk`` features.
        """
        return self.linear_decoder(input)

    @property
    def num_spk(self):
        """Number of speakers passed to the constructor."""
        return self._num_spk