update code
Browse files
- main.py +2 -2
- models/blocks.py +1 -1
- models/frn.py +2 -2
main.py
CHANGED
@@ -4,7 +4,7 @@ import os
|
|
4 |
import pytorch_lightning as pl
|
5 |
import soundfile as sf
|
6 |
import torch
|
7 |
-
from pytorch_lightning.callbacks import ModelCheckpoint
|
8 |
from pytorch_lightning.utilities.model_summary import summarize
|
9 |
from torch.utils.data import DataLoader
|
10 |
|
@@ -66,7 +66,7 @@ def train():
|
|
66 |
gpus=len(gpus),
|
67 |
max_epochs=CONFIG.TRAIN.epochs,
|
68 |
accelerator="gpu" if len(gpus) > 1 else None,
|
69 |
-
callbacks=[checkpoint_callback
|
70 |
)
|
71 |
|
72 |
print(model.hparams)
|
|
|
4 |
import pytorch_lightning as pl
|
5 |
import soundfile as sf
|
6 |
import torch
|
7 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
8 |
from pytorch_lightning.utilities.model_summary import summarize
|
9 |
from torch.utils.data import DataLoader
|
10 |
|
|
|
66 |
gpus=len(gpus),
|
67 |
max_epochs=CONFIG.TRAIN.epochs,
|
68 |
accelerator="gpu" if len(gpus) > 1 else None,
|
69 |
+
callbacks=[checkpoint_callback]
|
70 |
)
|
71 |
|
72 |
print(model.hparams)
|
models/blocks.py
CHANGED
@@ -117,7 +117,7 @@ class Predictor(pl.LightningModule): # mel
|
|
117 |
fb = librosa.filters.mel(sr=sr, n_fft=self.window_size, n_mels=self.n_mels)[:, 1:]
|
118 |
self.fb = torch.from_numpy(fb).unsqueeze(0).unsqueeze(0)
|
119 |
self.lstm = nn.LSTM(input_size=self.n_mels, hidden_size=self.lstm_dim, bidirectional=False,
|
120 |
-
num_layers=self.lstm_layers)
|
121 |
self.expand_dim = nn.Linear(self.lstm_dim, self.n_mels)
|
122 |
self.inv_mel = nn.Linear(self.n_mels, self.hop_size)
|
123 |
|
|
|
117 |
fb = librosa.filters.mel(sr=sr, n_fft=self.window_size, n_mels=self.n_mels)[:, 1:]
|
118 |
self.fb = torch.from_numpy(fb).unsqueeze(0).unsqueeze(0)
|
119 |
self.lstm = nn.LSTM(input_size=self.n_mels, hidden_size=self.lstm_dim, bidirectional=False,
|
120 |
+
num_layers=self.lstm_layers, batch_first=True)
|
121 |
self.expand_dim = nn.Linear(self.lstm_dim, self.n_mels)
|
122 |
self.inv_mel = nn.Linear(self.n_mels, self.hop_size)
|
123 |
|
models/frn.py
CHANGED
@@ -66,7 +66,7 @@ class PLCModel(pl.LightningModule):
|
|
66 |
|
67 |
x = x.permute(3, 0, 1, 2).unsqueeze(-1)
|
68 |
prev_mag = torch.zeros((B, 1, F, 1), device=x.device)
|
69 |
-
predictor_state = torch.zeros((2, self.predictor.lstm_layers,
|
70 |
mlp_state = torch.zeros((self.encoder.depth, 2, 1, B, self.encoder.dim), device=x.device)
|
71 |
result = []
|
72 |
for step in x:
|
@@ -201,7 +201,7 @@ class OnnxWrapper(pl.LightningModule):
|
|
201 |
super().__init__(*args, **kwargs)
|
202 |
self.model = model
|
203 |
batch_size = 1
|
204 |
-
pred_states = torch.zeros((2, 1,
|
205 |
mlp_states = torch.zeros((model.encoder.depth, 2, 1, batch_size, model.encoder.dim))
|
206 |
mag = torch.zeros((batch_size, 1, model.hop_size, 1))
|
207 |
x = torch.randn(batch_size, model.hop_size + 1, 2)
|
|
|
66 |
|
67 |
x = x.permute(3, 0, 1, 2).unsqueeze(-1)
|
68 |
prev_mag = torch.zeros((B, 1, F, 1), device=x.device)
|
69 |
+
predictor_state = torch.zeros((2, self.predictor.lstm_layers, B, self.predictor.lstm_dim), device=x.device)
|
70 |
mlp_state = torch.zeros((self.encoder.depth, 2, 1, B, self.encoder.dim), device=x.device)
|
71 |
result = []
|
72 |
for step in x:
|
|
|
201 |
super().__init__(*args, **kwargs)
|
202 |
self.model = model
|
203 |
batch_size = 1
|
204 |
+
pred_states = torch.zeros((2, 1, batch_size, model.predictor.lstm_dim))
|
205 |
mlp_states = torch.zeros((model.encoder.depth, 2, 1, batch_size, model.encoder.dim))
|
206 |
mag = torch.zeros((batch_size, 1, model.hop_size, 1))
|
207 |
x = torch.randn(batch_size, model.hop_size + 1, 2)
|