from typing import Optional

import torch
from torch import nn
from modules.wenet_extractor.utils.common import get_activation


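class TransducerJoint(torch.nn.Module):
    """Joint network of a transducer (RNN-T) model.

    Optionally projects the encoder output [B, T, E] and the prediction
    network output [B, U, P] to ``join_dim`` (``prejoin_linear``), adds them
    with broadcasting to [B, T, U, join_dim], optionally applies another
    linear layer (``postjoin_linear``), then the activation and a final
    projection to logits of shape [B, T, U, voca_size].
    """
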
    def __init__(
        self,
        voca_size: int,
        enc_output_size: int,
        pred_output_size: int,
        join_dim: int,
        prejoin_linear: bool = True,
        postjoin_linear: bool = False,
        joint_mode: str = "add",
        activation: str = "tanh",
    ):
        # TODO(Mddct): concat in future
        assert joint_mode in ["add"]
        super().__init__()

        self.activation = get_activation(activation)
        self.prejoin_linear = prejoin_linear
        self.postjoin_linear = postjoin_linear
        self.joint_mode = joint_mode

        # Without the pre-join projections, both inputs are added directly and
        # fed to post_ffn/ffn_out, so their sizes must already equal join_dim.
        if not self.prejoin_linear:
            assert enc_output_size == pred_output_size == join_dim
        # torchscript compatibility
        self.enc_ffn: Optional[nn.Linear] = None
        self.pred_ffn: Optional[nn.Linear] = None
        if self.prejoin_linear:
            self.enc_ffn = nn.Linear(enc_output_size, join_dim)
            self.pred_ffn = nn.Linear(pred_output_size, join_dim)
        # torchscript compatibility
        self.post_ffn: Optional[nn.Linear] = None
        if self.postjoin_linear:
            self.post_ffn = nn.Linear(join_dim, join_dim)

        self.ffn_out = nn.Linear(join_dim, voca_size)

    def forward(
        self, enc_out: torch.Tensor, pred_out: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            enc_out (torch.Tensor): encoder output, [B, T, E]
            pred_out (torch.Tensor): prediction network output, [B, U, P]
        Returns:
            torch.Tensor: joint logits, [B, T, U, V], where V = voca_size
        """
        if (
            self.prejoin_linear
            and self.enc_ffn is not None
            and self.pred_ffn is not None
        ):
            enc_out = self.enc_ffn(enc_out)  # [B,T,E] -> [B,T,D], D = join_dim
            pred_out = self.pred_ffn(pred_out)  # [B,U,P] -> [B,U,D]

        enc_out = enc_out.unsqueeze(2)  # [B,T,D] -> [B,T,1,D]
        pred_out = pred_out.unsqueeze(1)  # [B,U,D] -> [B,1,U,D]

        # TODO(Mddct): concat joint
        _ = self.joint_mode
        out = enc_out + pred_out  # broadcast -> [B,T,U,D]

        if self.postjoin_linear and self.post_ffn is not None:
            out = self.post_ffn(out)  # [B,T,U,D] -> [B,T,U,D]

        out = self.activation(out)
        out = self.ffn_out(out)  # [B,T,U,D] -> [B,T,U,V]
        return out
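

# Hypothetical illustration (not part of the original module): a minimal,
# self-contained sketch of the shape arithmetic that TransducerJoint.forward
# performs with joint_mode="add" and prejoin_linear=True. It rebuilds the
# computation from plain torch layers and assumes tanh as the activation;
# the function name and default sizes are made up for this example.
def _additive_joint_shape_demo(
    batch: int = 2,
    frames: int = 5,
    labels: int = 3,
    enc_dim: int = 8,
    pred_dim: int = 6,
    join_dim: int = 4,
    vocab: int = 10,
) -> tuple:
    # Local imports keep this sketch self-contained.
    import torch
    from torch import nn

    enc_out = torch.randn(batch, frames, enc_dim)  # [B, T, E]
    pred_out = torch.randn(batch, labels, pred_dim)  # [B, U, P]
    enc_ffn = nn.Linear(enc_dim, join_dim)  # E -> D
    pred_ffn = nn.Linear(pred_dim, join_dim)  # P -> D
    ffn_out = nn.Linear(join_dim, vocab)  # D -> V

    # [B, T, 1, D] + [B, 1, U, D] broadcasts to [B, T, U, D]
    joint = enc_ffn(enc_out).unsqueeze(2) + pred_ffn(pred_out).unsqueeze(1)
    logits = ffn_out(torch.tanh(joint))  # [B, T, U, V]
    return tuple(logits.shape)


# With the defaults, _additive_joint_shape_demo() returns (2, 5, 3, 10):
# one vocabulary distribution for every (frame, label-prefix) pair.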