|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class PitchRNN(nn.Module):
    """Predict one scalar per frame from a MIDI curve and a mel spectrogram.

    Both inputs are lifted to ``hidden_size`` channels by independent
    point-wise (``kernel_size=1``) Conv1d stacks. The two encodings are
    concatenated along the channel axis, run through a bidirectional GRU over
    time, and projected to a single value per frame.

    Args:
        n_mels: Number of channels of the ``sp`` input.
        hidden_size: Width of each encoder's output and of each GRU direction.
        num_layers: Number of stacked GRU layers. Defaults to 2, the value
            that was previously hard-coded.
    """

    def __init__(self, n_mels: int, hidden_size: int, num_layers: int = 2):
        super().__init__()

        # Point-wise encoder for the spectrogram: n_mels -> 2H -> H channels.
        self.sp_linear = nn.Sequential(
            nn.Conv1d(n_mels, hidden_size * 2, kernel_size=1),
            nn.SiLU(),
            nn.Conv1d(hidden_size * 2, hidden_size, kernel_size=1),
            nn.SiLU(),
        )

        # Point-wise encoder for the single-channel MIDI input: 1 -> 2H -> H.
        self.midi_linear = nn.Sequential(
            nn.Conv1d(1, hidden_size * 2, kernel_size=1),
            nn.SiLU(),
            nn.Conv1d(hidden_size * 2, hidden_size, kernel_size=1),
            nn.SiLU(),
        )

        self.hidden_size = hidden_size

        # Input width is 2H (both encoders concatenated); because the GRU is
        # bidirectional its per-frame output is also 2H.
        self.rnn = nn.GRU(
            input_size=hidden_size * 2,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Per-frame head: 2H (both GRU directions) -> H -> 1.
        self.linear = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, midi: torch.Tensor, sp: torch.Tensor) -> torch.Tensor:
        """Run the model.

        Args:
            midi: Tensor of shape (batch, frames). A channel axis is inserted
                so it matches the single-input-channel Conv1d. Integer tensors
                are accepted and cast to the model's parameter dtype.
            sp: Tensor of shape (batch, n_mels, frames). The batch and frame
                counts must match ``midi`` for the channel concatenation.

        Returns:
            Tensor of shape (batch, frames): one prediction per frame.
        """
        # Conv1d rejects integer inputs and mismatched float dtypes; cast to
        # the encoder's parameter dtype (a no-op when the dtype already matches).
        midi = midi.to(self.midi_linear[0].weight.dtype)
        midi = self.midi_linear(midi.unsqueeze(1))  # (B, 1, T) -> (B, H, T)
        sp = self.sp_linear(sp)                     # (B, n_mels, T) -> (B, H, T)

        x = torch.cat([midi, sp], dim=1)  # (B, 2H, T)
        x = x.transpose(1, 2)             # (B, T, 2H): the GRU is batch_first
        x, _ = self.rnn(x)                # (B, T, 2H), both directions
        x = self.linear(x)                # (B, T, 1)
        return x.squeeze(-1)              # (B, T)
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: a batch of 4 inputs with 100 mel channels and 128 frames.
    model = PitchRNN(100, 256)

    sp = torch.rand((4, 100, 128))
    midi = torch.rand((4, 128))

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        y = model(midi, sp)

    # Expected: (4, 128), one value per frame for each batch item.
    print(tuple(y.shape))