saeki committed on
Commit
200d40d
·
1 Parent(s): 61192e1
Files changed (2)
  1. lightning_module.py +2 -85
  2. utils.py +0 -12
lightning_module.py CHANGED
@@ -11,13 +11,6 @@ from model import (
     MultiScaleSpectralLoss,
     GSTModule,
 )
-from utils import (
-    manual_logging,
-    load_vocoder,
-    plot_and_save_mels,
-    plot_and_save_mels_all,
-)
-
 
 class PretrainLightningModule(pl.LightningModule):
     def __init__(self, config):
@@ -32,7 +25,7 @@ class PretrainLightningModule(pl.LightningModule):
         self.channelfeats = ChannelFeatureModule(config)
 
         self.channel = ChannelModule(config)
-        self.vocoder = load_vocoder(config)
+        self.vocoder = None
 
         self.criteria_a = MultiScaleSpectralLoss(config)
         if "feature_loss" in config["train"]:
@@ -154,8 +147,6 @@ class PretrainLightningModule(pl.LightningModule):
             prog_bar=True,
             logger=True,
         )
-        self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
-        self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")
 
     def test_step(self, batch, batch_idx):
         if self.config["general"]["use_gst"]:
@@ -224,24 +215,6 @@ class PretrainLightningModule(pl.LightningModule):
                 sample_rate=self.config["preprocess"]["sampling_rate"],
                 channels_first=True,
             )
-            plot_and_save_mels(
-                out[key][0, ...].cpu(),
-                mel_dir / "{}-{}.png".format(idx, key),
-                self.config,
-            )
-            plot_and_save_mels_all(
-                out,
-                [
-                    "reconstructed",
-                    "remastered",
-                    "channeled",
-                    "input",
-                    "input_recons",
-                    "groundtruth",
-                ],
-                mel_dir / "{}-all.png".format(idx),
-                self.config,
-            )
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(
@@ -257,21 +230,6 @@ class PretrainLightningModule(pl.LightningModule):
         }
         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
 
-    def tflogger(self, logger_dict, data_type):
-        for lg in self.logger.experiment:
-            if type(lg).__name__ == "SummaryWriter":
-                tensorboard = lg
-        for key in logger_dict.keys():
-            manual_logging(
-                logger=tensorboard,
-                item=logger_dict[key],
-                idx=0,
-                tag=key,
-                global_step=self.global_step,
-                data_type=data_type,
-                config=self.config,
-            )
-
 
 class SSLBaseModule(pl.LightningModule):
     def __init__(self, config):
@@ -299,7 +257,7 @@ class SSLBaseModule(pl.LightningModule):
             pre_model.channelfeats.state_dict(), strict=False
         )
 
-        self.vocoder = load_vocoder(config)
+        self.vocoder = None
         self.criteria = self.get_loss_function(config)
 
     def training_step(self, batch, batch_idx):
@@ -405,32 +363,6 @@ class SSLBaseModule(pl.LightningModule):
                 sample_rate=self.config["preprocess"]["sampling_rate"],
                 channels_first=True,
             )
-            plot_and_save_mels(
-                out[key][0, ...].cpu(),
-                mel_dir / "{}-{}.png".format(idx, key),
-                self.config,
-            )
-            plot_and_save_mels_all(
-                out,
-                plot_keys,
-                mel_dir / "{}-all.png".format(idx),
-                self.config,
-            )
-
-    def tflogger(self, logger_dict, data_type):
-        for lg in self.logger.experiment:
-            if type(lg).__name__ == "SummaryWriter":
-                tensorboard = lg
-        for key in logger_dict.keys():
-            manual_logging(
-                logger=tensorboard,
-                item=logger_dict[key],
-                idx=0,
-                tag=key,
-                global_step=self.global_step,
-                data_type=data_type,
-                config=self.config,
-            )
 
 
 class SSLStepLightningModule(SSLBaseModule):
@@ -511,8 +443,6 @@ class SSLStepLightningModule(SSLBaseModule):
             prog_bar=True,
             logger=True,
         )
-        self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
-        self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")
 
     def optimizer_step(
         self,
@@ -754,8 +684,6 @@ class SSLDualLightningModule(SSLBaseModule):
             prog_bar=True,
             logger=True,
         )
-        self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
-        self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")
 
     def test_step(self, batch, batch_idx):
         if self.config["general"]["use_gst"]:
@@ -833,17 +761,6 @@ class SSLDualLightningModule(SSLBaseModule):
                 sample_rate=self.config["preprocess"]["sampling_rate"],
                 channels_first=True,
             )
-            plot_and_save_mels(
-                out[key][0, ...].cpu(),
-                mel_dir / "{}-{}.png".format(idx, key),
-                self.config,
-            )
-            plot_and_save_mels_all(
-                out,
-                plot_keys,
-                mel_dir / "{}-all.png".format(idx),
-                self.config,
-            )
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(
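Note: with `tflogger` and the `manual_logging` calls removed, validation epochs no longer push spectrogram images or audio clips to TensorBoard. If that behaviour is still wanted, a rough sketch is to talk to the `SummaryWriter` directly, much as the deleted helper did; the method name, tag names, and tensor shapes below are illustrative assumptions, not part of this repository:

def log_media(self, mel, wav, sample_rate):
    # Locate the TensorBoard SummaryWriter among the configured loggers,
    # mirroring the type check the removed tflogger performed.
    experiments = self.logger.experiment
    if not isinstance(experiments, (list, tuple)):
        experiments = [experiments]
    writer = next((e for e in experiments if type(e).__name__ == "SummaryWriter"), None)
    if writer is None:
        return
    # mel: (n_mels, frames) tensor scaled to [0, 1]; add_image expects CHW.
    writer.add_image("val/mel", mel.unsqueeze(0), global_step=self.global_step)
    # wav: (1, num_samples) tensor in [-1, 1].
    writer.add_audio("val/audio", wav, global_step=self.global_step, sample_rate=sample_rate)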
 
utils.py CHANGED
@@ -3,18 +3,6 @@ import json
 import torch
 import torchaudio
 
-def load_vocoder(config):
-    with open(
-        "hifigan/config_{}.json".format(config["general"]["feature_type"]), "r"
-    ) as f:
-        config_hifigan = hifigan.AttrDict(json.load(f))
-    vocoder = hifigan.Generator(config_hifigan)
-    vocoder.load_state_dict(torch.load(config["general"]["hifigan_path"])["generator"])
-    vocoder.remove_weight_norm()
-    for param in vocoder.parameters():
-        param.requires_grad = False
-    return vocoder
-
 def configure_args(config, args):
     for key in ["stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"]:
         if getattr(args, key) != None:
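Note: since `load_vocoder` is gone from utils.py and the modules now start with `self.vocoder = None`, waveform synthesis has to be wired up by the caller. A minimal sketch of attaching a HiFi-GAN generator after instantiation, reusing the logic of the deleted helper (the `hifigan` package, config filename pattern, and checkpoint key follow the removed code; the surrounding usage and the `config` dict are assumptions):

import json

import torch
import hifigan  # same package the deleted utils.load_vocoder relied on

from lightning_module import PretrainLightningModule

def load_vocoder(config):
    # Reconstruction of the helper removed in this commit; paths and keys follow the old code.
    with open(
        "hifigan/config_{}.json".format(config["general"]["feature_type"]), "r"
    ) as f:
        config_hifigan = hifigan.AttrDict(json.load(f))
    vocoder = hifigan.Generator(config_hifigan)
    vocoder.load_state_dict(torch.load(config["general"]["hifigan_path"])["generator"])
    vocoder.remove_weight_norm()
    for param in vocoder.parameters():
        param.requires_grad = False  # inference-only
    return vocoder

# Hypothetical usage: attach the vocoder after building a module,
# since __init__ now leaves self.vocoder as None.
module = PretrainLightningModule(config)
module.vocoder = load_vocoder(config)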
 