zyingt's picture
Upload 685 files
0d80816
raw
history blame contribute delete
No virus
2.33 kB
from typing import Optional
import torch
from torch import nn
from modules.wenet_extractor.utils.common import get_activation
class TransducerJoint(torch.nn.Module):
    """Joint network of an RNN-Transducer.

    Combines encoder frames and prediction-network states into logits for
    every (encoder frame, label position) pair. The steps are:

    1. Optionally project each branch to ``join_dim``.
    2. Add the two branches with broadcasting.
    3. Optionally apply a ``join_dim -> join_dim`` projection.
    4. Apply the activation.
    5. Project to ``voca_size``.

    Args:
        voca_size: Size of the last axis of the output.
        enc_output_size: Feature size of ``enc_out``.
        pred_output_size: Feature size of ``pred_out``.
        join_dim: Feature size at which the two branches are added.
        prejoin_linear: If True, project each branch to ``join_dim`` before
            the join. If False, ``enc_output_size`` and ``pred_output_size``
            must already equal ``join_dim``.
        postjoin_linear: If True, apply an extra ``join_dim -> join_dim``
            linear layer after the join.
        joint_mode: How the branches are combined; only ``"add"`` is
            currently supported.
        activation: Activation name, resolved via ``get_activation``.
    """

    def __init__(
        self,
        voca_size: int,
        enc_output_size: int,
        pred_output_size: int,
        join_dim: int,
        prejoin_linear: bool = True,
        postjoin_linear: bool = False,
        joint_mode: str = "add",
        activation: str = "tanh",
    ):
        # TODO(Mddct): concat in future
        assert joint_mode in ["add"], f"unsupported joint_mode: {joint_mode!r}"
        # Without pre-join projections the raw branch outputs are added
        # directly and then fed to post_ffn / ffn_out, both of which expect
        # join_dim input features. The sizes must therefore match whenever
        # prejoin_linear is False, regardless of postjoin_linear.
        if not prejoin_linear:
            assert enc_output_size == pred_output_size == join_dim, (
                "without prejoin_linear, enc_output_size, pred_output_size "
                "and join_dim must be equal, got "
                f"{enc_output_size}, {pred_output_size}, {join_dim}"
            )
        super().__init__()
        # NOTE: attribute name is misspelled (sic). It is kept as-is so that
        # pickled modules and external references to it keep working.
        self.activatoin = get_activation(activation)
        self.prejoin_linear = prejoin_linear
        self.postjoin_linear = postjoin_linear
        self.joint_mode = joint_mode

        # Declared as Optional up front for torchscript compatibility.
        self.enc_ffn: Optional[nn.Linear] = None
        self.pred_ffn: Optional[nn.Linear] = None
        if self.prejoin_linear:
            self.enc_ffn = nn.Linear(enc_output_size, join_dim)
            self.pred_ffn = nn.Linear(pred_output_size, join_dim)

        # Declared as Optional up front for torchscript compatibility.
        self.post_ffn: Optional[nn.Linear] = None
        if self.postjoin_linear:
            self.post_ffn = nn.Linear(join_dim, join_dim)

        self.ffn_out = nn.Linear(join_dim, voca_size)

    def forward(self, enc_out: torch.Tensor, pred_out: torch.Tensor) -> torch.Tensor:
        """Compute joint logits for every (encoder frame, label position) pair.

        Args:
            enc_out (torch.Tensor): [B, T, E]
            pred_out (torch.Tensor): [B, U, P]
        Return:
            [B, T, U, V], where V == voca_size
        """
        # The explicit None checks let torchscript refine the Optional types.
        if (
            self.prejoin_linear
            and self.enc_ffn is not None
            and self.pred_ffn is not None
        ):
            enc_out = self.enc_ffn(enc_out)  # [B,T,E] -> [B,T,D], D == join_dim
            pred_out = self.pred_ffn(pred_out)  # [B,U,P] -> [B,U,D]
        # Insert singleton axes so the addition broadcasts to [B,T,U,D].
        enc_out = enc_out.unsqueeze(2)  # [B,T,D] -> [B,T,1,D]
        pred_out = pred_out.unsqueeze(1)  # [B,U,D] -> [B,1,U,D]
        # TODO(Mddct): concat joint
        out = enc_out + pred_out  # [B,T,U,D]
        if self.postjoin_linear and self.post_ffn is not None:
            out = self.post_ffn(out)
        out = self.activatoin(out)
        return self.ffn_out(out)  # [B,T,U,D] -> [B,T,U,V]