import torch
import torch.nn as nn
from transformers.modeling_utils import ModuleUtilsMixin
from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm

from ....configuration_utils import ConfigMixin, register_to_config
from ....models import ModelMixin

class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    @register_to_config
    def __init__(
        self,
        max_length: int,
        vocab_size: int,
        d_model: int,
        dropout_rate: float,
        num_layers: int,
        num_heads: int,
        d_kv: int,
        d_ff: int,
        feed_forward_proj: str,
        is_decoder: bool = False,
    ):
        super().__init__()

        # Embedding table for the input token vocabulary.
        self.token_embedder = nn.Embedding(vocab_size, d_model)

        # Absolute positional embeddings; frozen (requires_grad=False), so they
        # are not updated during training.
        self.position_encoding = nn.Embedding(max_length, d_model)
        self.position_encoding.weight.requires_grad = False

        self.dropout_pre = nn.Dropout(p=dropout_rate)

        # Reuse the T5 encoder block implementation from transformers.
        t5config = T5Config(
            vocab_size=vocab_size,
            d_model=d_model,
            num_heads=num_heads,
            d_kv=d_kv,
            d_ff=d_ff,
            dropout_rate=dropout_rate,
            feed_forward_proj=feed_forward_proj,
            is_decoder=is_decoder,
            is_encoder_decoder=False,
        )

        self.encoders = nn.ModuleList()
        for lyr_num in range(num_layers):
            lyr = T5Block(t5config)
            self.encoders.append(lyr)

        self.layer_norm = T5LayerNorm(d_model)
        self.dropout_post = nn.Dropout(p=dropout_rate)

    def forward(self, encoder_input_tokens, encoder_inputs_mask):
        x = self.token_embedder(encoder_input_tokens)

        # Add positional embeddings for positions [0, seq_length).
        seq_length = encoder_input_tokens.shape[1]
        inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device)
        x += self.position_encoding(inputs_positions)

        x = self.dropout_pre(x)

        # Expand the 2D padding mask into the additive attention mask expected by T5Block.
        input_shape = encoder_input_tokens.size()
        extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape)

        for lyr in self.encoders:
            x = lyr(x, extended_attention_mask)[0]
        x = self.layer_norm(x)

        return self.dropout_post(x), encoder_inputs_mask
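

# Usage sketch (illustrative only, kept as a comment so nothing runs on import):
# the hyperparameter values below are assumptions for demonstration, not the
# configuration of any particular pretrained checkpoint.
#
#   encoder = SpectrogramNotesEncoder(
#       max_length=2048,
#       vocab_size=1536,
#       d_model=768,
#       dropout_rate=0.1,
#       num_layers=12,
#       num_heads=12,
#       d_kv=64,
#       d_ff=2048,
#       feed_forward_proj="gated-gelu",
#   )
#   tokens = torch.randint(0, 1536, (1, 2048))      # (batch, seq_len) token ids
#   mask = torch.ones(1, 2048, dtype=torch.long)    # 1 = attend, 0 = padding
#   hidden_states, out_mask = encoder(tokens, mask)
#   # hidden_states: (1, 2048, 768); out_mask is the input mask, returned unchanged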