saeki committed
Commit · 200d40d
Parent(s): 61192e1
fix
Files changed:
- lightning_module.py +2 -85
- utils.py +0 -12
lightning_module.py
CHANGED
@@ -11,13 +11,6 @@ from model import (
     MultiScaleSpectralLoss,
     GSTModule,
 )
-from utils import (
-    manual_logging,
-    load_vocoder,
-    plot_and_save_mels,
-    plot_and_save_mels_all,
-)
-

 class PretrainLightningModule(pl.LightningModule):
     def __init__(self, config):
@@ -32,7 +25,7 @@ class PretrainLightningModule(pl.LightningModule):
         self.channelfeats = ChannelFeatureModule(config)

         self.channel = ChannelModule(config)
-        self.vocoder = load_vocoder(config)
+        self.vocoder = None

         self.criteria_a = MultiScaleSpectralLoss(config)
         if "feature_loss" in config["train"]:
@@ -154,8 +147,6 @@ class PretrainLightningModule(pl.LightningModule):
             prog_bar=True,
             logger=True,
         )
-        self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
-        self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")

     def test_step(self, batch, batch_idx):
         if self.config["general"]["use_gst"]:
@@ -224,24 +215,6 @@ class PretrainLightningModule(pl.LightningModule):
                 sample_rate=self.config["preprocess"]["sampling_rate"],
                 channels_first=True,
             )
-            plot_and_save_mels(
-                out[key][0, ...].cpu(),
-                mel_dir / "{}-{}.png".format(idx, key),
-                self.config,
-            )
-        plot_and_save_mels_all(
-            out,
-            [
-                "reconstructed",
-                "remastered",
-                "channeled",
-                "input",
-                "input_recons",
-                "groundtruth",
-            ],
-            mel_dir / "{}-all.png".format(idx),
-            self.config,
-        )

     def configure_optimizers(self):
         optimizer = torch.optim.Adam(
@@ -257,21 +230,6 @@ class PretrainLightningModule(pl.LightningModule):
         }
         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

-    def tflogger(self, logger_dict, data_type):
-        for lg in self.logger.experiment:
-            if type(lg).__name__ == "SummaryWriter":
-                tensorboard = lg
-        for key in logger_dict.keys():
-            manual_logging(
-                logger=tensorboard,
-                item=logger_dict[key],
-                idx=0,
-                tag=key,
-                global_step=self.global_step,
-                data_type=data_type,
-                config=self.config,
-            )
-

 class SSLBaseModule(pl.LightningModule):
     def __init__(self, config):
@@ -299,7 +257,7 @@ class SSLBaseModule(pl.LightningModule):
             pre_model.channelfeats.state_dict(), strict=False
         )

-        self.vocoder = load_vocoder(config)
+        self.vocoder = None
         self.criteria = self.get_loss_function(config)

     def training_step(self, batch, batch_idx):
@@ -405,32 +363,6 @@ class SSLBaseModule(pl.LightningModule):
                 sample_rate=self.config["preprocess"]["sampling_rate"],
                 channels_first=True,
             )
-            plot_and_save_mels(
-                out[key][0, ...].cpu(),
-                mel_dir / "{}-{}.png".format(idx, key),
-                self.config,
-            )
-        plot_and_save_mels_all(
-            out,
-            plot_keys,
-            mel_dir / "{}-all.png".format(idx),
-            self.config,
-        )
-
-    def tflogger(self, logger_dict, data_type):
-        for lg in self.logger.experiment:
-            if type(lg).__name__ == "SummaryWriter":
-                tensorboard = lg
-        for key in logger_dict.keys():
-            manual_logging(
-                logger=tensorboard,
-                item=logger_dict[key],
-                idx=0,
-                tag=key,
-                global_step=self.global_step,
-                data_type=data_type,
-                config=self.config,
-            )


 class SSLStepLightningModule(SSLBaseModule):
@@ -511,8 +443,6 @@ class SSLStepLightningModule(SSLBaseModule):
             prog_bar=True,
             logger=True,
         )
-        self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
-        self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")

     def optimizer_step(
         self,
@@ -754,8 +684,6 @@ class SSLDualLightningModule(SSLBaseModule):
             prog_bar=True,
            logger=True,
         )
-        self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
-        self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")

     def test_step(self, batch, batch_idx):
         if self.config["general"]["use_gst"]:
@@ -833,17 +761,6 @@ class SSLDualLightningModule(SSLBaseModule):
                 sample_rate=self.config["preprocess"]["sampling_rate"],
                 channels_first=True,
             )
-            plot_and_save_mels(
-                out[key][0, ...].cpu(),
-                mel_dir / "{}-{}.png".format(idx, key),
-                self.config,
-            )
-        plot_and_save_mels_all(
-            out,
-            plot_keys,
-            mel_dir / "{}-all.png".format(idx),
-            self.config,
-        )

     def configure_optimizers(self):
         optimizer = torch.optim.Adam(
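Note on the removed logging path: the hunks above also delete tflogger() and its manual_logging calls, which mirrored validation spectrogram images and audio to TensorBoard. The snippet below is a minimal, hypothetical sketch of how that behaviour could be reproduced with the stock torch.utils.tensorboard API; manual_logging itself is not shown in this commit, so the add_image/add_audio calls, the assumed tensor shapes, and the helper name log_to_tensorboard are assumptions rather than the original implementation.

from torch.utils.tensorboard import SummaryWriter


def log_to_tensorboard(loggers, logger_dict, data_type, global_step, sampling_rate):
    # Pick the TensorBoard writer out of the attached experiment objects,
    # the same way the removed tflogger() scanned self.logger.experiment.
    tensorboard = None
    for lg in loggers:
        if isinstance(lg, SummaryWriter):
            tensorboard = lg
    if tensorboard is None:
        return
    for tag, item in logger_dict.items():
        sample = item[0].detach().cpu()  # first element of the batch (idx=0)
        if data_type == "image":
            # Treat the item as a 2-D feature map such as a mel-spectrogram.
            tensorboard.add_image(tag, sample.unsqueeze(0), global_step)
        elif data_type == "audio":
            tensorboard.add_audio(
                tag, sample.reshape(1, -1), global_step, sample_rate=sampling_rate
            )

A LightningModule could call this as log_to_tensorboard(self.logger.experiment, outputs[-1]["logger_dict"][0], "image", self.global_step, self.config["preprocess"]["sampling_rate"]), matching the arguments of the deleted calls; again, a sketch rather than the repository's own API.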
utils.py
CHANGED
@@ -3,18 +3,6 @@ import json
 import torch
 import torchaudio

-def load_vocoder(config):
-    with open(
-        "hifigan/config_{}.json".format(config["general"]["feature_type"]), "r"
-    ) as f:
-        config_hifigan = hifigan.AttrDict(json.load(f))
-    vocoder = hifigan.Generator(config_hifigan)
-    vocoder.load_state_dict(torch.load(config["general"]["hifigan_path"])["generator"])
-    vocoder.remove_weight_norm()
-    for param in vocoder.parameters():
-        param.requires_grad = False
-    return vocoder
-
 def configure_args(config, args):
     for key in ["stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"]:
         if getattr(args, key) != None: