import torch
import torch.nn as nn
from transformers.modeling_utils import ModuleUtilsMixin
from transformers.models.t5.modeling_t5 import (
    T5Block,
    T5Config,
    T5LayerNorm,
)

from ....configuration_utils import ConfigMixin, register_to_config
from ....models import ModelMixin


class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    @register_to_config
    def __init__(
        self,
        input_dims: int,
        targets_context_length: int,
        d_model: int,
        dropout_rate: float,
        num_layers: int,
        num_heads: int,
        d_kv: int,
        d_ff: int,
        feed_forward_proj: str,
        is_decoder: bool = False,
    ):
        super().__init__()

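        # project continuous spectrogram inputs to the transformer model width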
        self.input_proj = nn.Linear(input_dims, d_model, bias=False)

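        # frozen (non-trainable) positional embedding table over the targets context length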
        self.position_encoding = nn.Embedding(targets_context_length, d_model)
        self.position_encoding.weight.requires_grad = False

        self.dropout_pre = nn.Dropout(p=dropout_rate)

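        # T5 configuration shared by every block in the encoder stack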
        t5config = T5Config(
            d_model=d_model,
            num_heads=num_heads,
            d_kv=d_kv,
            d_ff=d_ff,
            feed_forward_proj=feed_forward_proj,
            dropout_rate=dropout_rate,
            is_decoder=is_decoder,
            is_encoder_decoder=False,
        )
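        # stack of T5 blocks acting as the transformer encoder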
        self.encoders = nn.ModuleList()
        for lyr_num in range(num_layers):
            lyr = T5Block(t5config)
            self.encoders.append(lyr)

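        # final layer norm and dropout applied to the encoder output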
        self.layer_norm = T5LayerNorm(d_model)
        self.dropout_post = nn.Dropout(p=dropout_rate)

    def forward(self, encoder_inputs, encoder_inputs_mask):
        x = self.input_proj(encoder_inputs)

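        # build position ids over the padded input length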
        max_positions = encoder_inputs.shape[1]
        input_positions = torch.arange(max_positions, device=encoder_inputs.device)

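        # roll the position ids by each example's unpadded length (mask sum), then add the frozen positional embeddings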
        seq_lens = encoder_inputs_mask.sum(-1)
        input_positions = torch.roll(input_positions.unsqueeze(0), tuple(seq_lens.tolist()), dims=0)
        x += self.position_encoding(input_positions)

        x = self.dropout_pre(x)

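        # convert the padding mask into an extended (additive) attention mask for the T5 blocks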
        input_shape = encoder_inputs.size()
        extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape)

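        # run the sequence through the T5 blocks, then apply the final layer norm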
        for lyr in self.encoders:
            x = lyr(x, extended_attention_mask)[0]
        x = self.layer_norm(x)

        return self.dropout_post(x), encoder_inputs_mask