"""Transducer joint network implementation."""

import torch

from espnet.nets.pytorch_backend.nets_utils import get_activation


class JointNetwork(torch.nn.Module):
    """Transducer joint network module.

    Projects encoder and decoder hidden states into a shared joint space,
    sums them, applies a non-linearity, and projects the result to the
    output vocabulary.

    Args:
        vocab_size: Output dimension (vocabulary size).
        encoder_output_size: Dimension of encoder hidden states (D_enc).
        decoder_output_size: Dimension of decoder hidden states (D_dec).
        joint_space_size: Dimension of joint space
        joint_activation_type: Activation name for joint network, forwarded
            to ``get_activation``.

    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        decoder_output_size: int,
        joint_space_size: int,
        joint_activation_type: str,
    ):
        """Joint network initializer."""
        super().__init__()

        self.lin_enc = torch.nn.Linear(encoder_output_size, joint_space_size)
        # The two projections are summed in ``forward``; ``lin_enc`` already
        # contributes a bias, so a second one here would be redundant.
        self.lin_dec = torch.nn.Linear(
            decoder_output_size, joint_space_size, bias=False
        )

        self.lin_out = torch.nn.Linear(joint_space_size, vocab_size)

        self.joint_activation = get_activation(joint_activation_type)

    def forward(
        self, h_enc: torch.Tensor, h_dec: torch.Tensor, is_aux: bool = False
    ) -> torch.Tensor:
        """Joint computation of z.

        Args:
            h_enc: Batch of expanded hidden state (B, T, 1, D_enc).
                When ``is_aux`` is True, ``h_enc`` is used as-is (``lin_enc``
                is skipped), so its last dimension must already be
                ``joint_space_size``.
            h_dec: Batch of expanded hidden state (B, 1, U, D_dec)
            is_aux: Whether ``h_enc`` is already projected into joint space
                (e.g. by the caller for an auxiliary objective -- TODO
                confirm against callers).

        Returns:
            z: Output (B, T, U, vocab_size). Unnormalized scores: no softmax
                is applied here.

        """
        # Encoder states are projected into joint space unless the caller
        # already did so.
        enc_proj = h_enc if is_aux else self.lin_enc(h_enc)

        # Singleton T / U axes broadcast against each other in the sum.
        z = self.joint_activation(enc_proj + self.lin_dec(h_dec))

        return self.lin_out(z)