Spaces:
Running
on
A10G
Running
on
A10G
# This module is from [WeNet](https://github.com/wenet-e2e/wenet). | |
# ## Citations | |
# ```bibtex | |
# @inproceedings{yao2021wenet, | |
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit}, | |
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin}, | |
# booktitle={Proc. Interspeech}, | |
# year={2021}, | |
# address={Brno, Czech Republic }, | |
# organization={IEEE} | |
# } | |
# @article{zhang2022wenet, | |
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit}, | |
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei}, | |
# journal={arXiv preprint arXiv:2203.15455}, | |
# year={2022} | |
# } | |
# | |
"""Positionwise feed forward layer definition.""" | |
import torch | |
class PositionwiseFeedForward(torch.nn.Module): | |
"""Positionwise feed forward layer. | |
FeedForward are appied on each position of the sequence. | |
The output dim is same with the input dim. | |
Args: | |
idim (int): Input dimenstion. | |
hidden_units (int): The number of hidden units. | |
dropout_rate (float): Dropout rate. | |
activation (torch.nn.Module): Activation function | |
""" | |
def __init__( | |
self, | |
idim: int, | |
hidden_units: int, | |
dropout_rate: float, | |
activation: torch.nn.Module = torch.nn.ReLU(), | |
adaptive_scale: bool = False, | |
init_weights: bool = False, | |
): | |
"""Construct a PositionwiseFeedForward object.""" | |
super(PositionwiseFeedForward, self).__init__() | |
self.idim = idim | |
self.hidden_units = hidden_units | |
self.w_1 = torch.nn.Linear(idim, hidden_units) | |
self.activation = activation | |
self.dropout = torch.nn.Dropout(dropout_rate) | |
self.w_2 = torch.nn.Linear(hidden_units, idim) | |
self.ada_scale = None | |
self.ada_bias = None | |
self.adaptive_scale = adaptive_scale | |
self.ada_scale = torch.nn.Parameter( | |
torch.ones([1, 1, idim]), requires_grad=adaptive_scale | |
) | |
self.ada_bias = torch.nn.Parameter( | |
torch.zeros([1, 1, idim]), requires_grad=adaptive_scale | |
) | |
if init_weights: | |
self.init_weights() | |
def init_weights(self): | |
ffn1_max = self.idim**-0.5 | |
ffn2_max = self.hidden_units**-0.5 | |
torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max) | |
torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max) | |
torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max) | |
torch.nn.init.uniform_(self.w_2.bias.data, -ffn2_max, ffn2_max) | |
def forward(self, xs: torch.Tensor) -> torch.Tensor: | |
"""Forward function. | |
Args: | |
xs: input tensor (B, L, D) | |
Returns: | |
output tensor, (B, L, D) | |
""" | |
if self.adaptive_scale: | |
xs = self.ada_scale * xs + self.ada_bias | |
return self.w_2(self.dropout(self.activation(self.w_1(xs)))) | |