# aet_demo/lightning_module.py
import torch
import pytorch_lightning as pl
import torchaudio
import os
import pathlib
import tqdm
from model import (
EncoderModule,
ChannelFeatureModule,
ChannelModule,
MultiScaleSpectralLoss,
GSTModule,
)
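
# Lightning modules for the self-supervised audio restoration pipeline:
#   - PretrainLightningModule: supervised pretraining of the feature encoder
#     together with the channel (degradation) simulator.
#   - SSLBaseModule: shared forward/test logic for the self-supervised stages.
#   - SSLStepLightningModule: alternating optimization of the channel side and
#     the encoder (two optimizers, switched by epoch).
#   - SSLDualLightningModule: joint training with an auxiliary "task" branch
#     that round-trips clean task audio through the estimated channel.
# Note: the hooks used below (validation_epoch_end, test_epoch_end,
# optimizer_idx in training_step/optimizer_step) target the legacy
# pytorch_lightning < 2.0 API.
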
class PretrainLightningModule(pl.LightningModule):
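    """Supervised pretraining stage.

    Trains (i) an encoder that maps degraded mel spectrograms to clean
    features (mel spectrograms, or mel cepstra when feature_type is
    "vocfeats") and (ii) a channel module that re-degrades the clean
    auxiliary waveform to match the observed recording. feature_type must be
    "melspec" or "vocfeats"; other values leave the encoder loss undefined.
    self.vocoder must be assigned externally before test_step is run.
    """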
def __init__(self, config):
super().__init__()
self.save_hyperparameters()
self.config = config
if config["general"]["use_gst"]:
self.encoder = EncoderModule(config)
self.gst = GSTModule(config)
else:
self.encoder = EncoderModule(config, use_channel=True)
self.channelfeats = ChannelFeatureModule(config)
self.channel = ChannelModule(config)
self.vocoder = None
self.criteria_a = MultiScaleSpectralLoss(config)
if "feature_loss" in config["train"]:
if config["train"]["feature_loss"]["type"] == "mae":
self.criteria_b = torch.nn.L1Loss()
else:
self.criteria_b = torch.nn.MSELoss()
else:
            self.criteria_b = torch.nn.L1Loss()  # default feature loss when "feature_loss" is not configured
self.alpha = config["train"]["alpha"]
def forward(self, melspecs, wavsaux):
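        """Run the encoder and the channel simulator.

        Shapes (inferred from the transposes below): melspecs is
        (batch, n_mels, frames), wavsaux is (batch, samples). Returns the
        predicted clean features (batch, n_feats, frames) and the
        pseudo-degraded waveform.
        """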
if self.config["general"]["use_gst"]:
enc_out = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
chfeats = self.gst(melspecs.transpose(1, 2))
else:
enc_out, enc_hidden = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
chfeats = self.channelfeats(enc_hidden)
enc_out = enc_out.squeeze(1).transpose(1, 2)
wavsdeg = self.channel(wavsaux, chfeats)
return enc_out, wavsdeg
def training_step(self, batch, batch_idx):
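        # Joint objective: alpha * waveform reconstruction loss
        # + (1 - alpha) * encoder feature loss (alpha set in __init__).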
if self.config["general"]["use_gst"]:
enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
chfeats = self.gst(batch["melspecs"].transpose(1, 2))
else:
enc_out, enc_hidden = self.encoder(
batch["melspecs"].unsqueeze(1).transpose(2, 3)
)
chfeats = self.channelfeats(enc_hidden)
enc_out = enc_out.squeeze(1).transpose(1, 2)
wavsdeg = self.channel(batch["wavsaux"], chfeats)
loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
if self.config["general"]["feature_type"] == "melspec":
loss_encoder = self.criteria_b(enc_out, batch["melspecsaux"])
elif self.config["general"]["feature_type"] == "vocfeats":
loss_encoder = self.criteria_b(enc_out, batch["melceps"])
loss = self.alpha * loss_recons + (1.0 - self.alpha) * loss_encoder
self.log(
"train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
)
self.log(
"train_loss_recons",
loss_recons,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
self.log(
"train_loss_encoder",
loss_encoder,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss
def validation_step(self, batch, batch_idx):
if self.config["general"]["use_gst"]:
enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
chfeats = self.gst(batch["melspecs"].transpose(1, 2))
else:
enc_out, enc_hidden = self.encoder(
batch["melspecs"].unsqueeze(1).transpose(2, 3)
)
chfeats = self.channelfeats(enc_hidden)
enc_out = enc_out.squeeze(1).transpose(1, 2)
wavsdeg = self.channel(batch["wavsaux"], chfeats)
loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
if self.config["general"]["feature_type"] == "melspec":
val_aux_feats = batch["melspecsaux"]
feats_name = "melspec"
loss_encoder = self.criteria_b(enc_out, val_aux_feats)
elif self.config["general"]["feature_type"] == "vocfeats":
val_aux_feats = batch["melceps"]
feats_name = "melcep"
loss_encoder = self.criteria_b(enc_out, val_aux_feats)
loss = self.alpha * loss_recons + (1.0 - self.alpha) * loss_encoder
logger_img_dict = {
"val_src_melspec": batch["melspecs"],
"val_pred_{}".format(feats_name): enc_out,
"val_aux_{}".format(feats_name): val_aux_feats,
}
logger_wav_dict = {
"val_src_wav": batch["wavs"],
"val_pred_wav": wavsdeg,
"val_aux_wav": batch["wavsaux"],
}
return {
"val_loss": loss,
"val_loss_recons": loss_recons,
"val_loss_encoder": loss_encoder,
"logger_dict": [logger_img_dict, logger_wav_dict],
}
def validation_epoch_end(self, outputs):
val_loss = torch.stack([out["val_loss"] for out in outputs]).mean().item()
val_loss_recons = (
torch.stack([out["val_loss_recons"] for out in outputs]).mean().item()
)
val_loss_encoder = (
torch.stack([out["val_loss_encoder"] for out in outputs]).mean().item()
)
self.log("val_loss", val_loss, on_epoch=True, prog_bar=True, logger=True)
self.log(
"val_loss_recons",
val_loss_recons,
on_epoch=True,
prog_bar=True,
logger=True,
)
self.log(
"val_loss_encoder",
val_loss_encoder,
on_epoch=True,
prog_bar=True,
logger=True,
)
def test_step(self, batch, batch_idx):
if self.config["general"]["use_gst"]:
enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
chfeats = self.gst(batch["melspecs"].transpose(1, 2))
else:
enc_out, enc_hidden = self.encoder(
batch["melspecs"].unsqueeze(1).transpose(2, 3)
)
chfeats = self.channelfeats(enc_hidden)
enc_out = enc_out.squeeze(1).transpose(1, 2)
wavsdeg = self.channel(batch["wavsaux"], chfeats)
if self.config["general"]["feature_type"] == "melspec":
enc_feats = enc_out
enc_feats_aux = batch["melspecsaux"]
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
enc_feats_aux = torch.cat(
(batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1
)
recons_wav = self.vocoder(enc_feats_aux).squeeze(1)
remas = self.vocoder(enc_feats).squeeze(1)
if self.config["general"]["feature_type"] == "melspec":
enc_feats_input = batch["melspecs"]
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats_input = torch.cat(
(batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1
)
input_recons = self.vocoder(enc_feats_input).squeeze(1)
if "wavsaux" in batch:
gt_wav = batch["wavsaux"]
else:
gt_wav = None
return {
"reconstructed": recons_wav,
"remastered": remas,
"channeled": wavsdeg,
"groundtruth": gt_wav,
"input": batch["wavs"],
"input_recons": input_recons,
}
def test_epoch_end(self, outputs):
wav_dir = (
pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs"
)
os.makedirs(wav_dir, exist_ok=True)
        mel_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels"
        )
        # mel_dir is prepared for optional mel-spectrogram plots; the plotting
        # utility itself is not included in this file.
        os.makedirs(mel_dir, exist_ok=True)
        print("Saving output waveforms ...")
for idx, out in enumerate(tqdm.tqdm(outputs)):
for key in [
"reconstructed",
"remastered",
"channeled",
"input",
"input_recons",
"groundtruth",
]:
                if out[key] is not None:
                    torchaudio.save(
                        str(wav_dir / "{}-{}.wav".format(idx, key)),
                        out[key][0, ...].unsqueeze(0).cpu(),
                        sample_rate=self.config["preprocess"]["sampling_rate"],
                        channels_first=True,
                    )
def configure_optimizers(self):
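        # Single Adam optimizer; ReduceLROnPlateau halves the LR (down to 1e-5)
        # when val_loss plateaus, checked every 3 epochs.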
optimizer = torch.optim.Adam(
self.parameters(), lr=self.config["train"]["learning_rate"]
)
lr_scheduler_config = {
"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, min_lr=1e-5, verbose=True
),
"interval": "epoch",
"frequency": 3,
"monitor": "val_loss",
}
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
class SSLBaseModule(pl.LightningModule):
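    """Base class for the self-supervised stages.

    Holds the encoder / channel modules (optionally initialized from a
    PretrainLightningModule checkpoint) and the shared forward/test logic.
    Subclasses must provide the training/validation steps, optimizers, and
    loss function. self.vocoder must be assigned externally before forward,
    training, or testing; its assumed interface maps acoustic features
    (batch, n_feats, frames) to waveforms (batch, 1, samples).
    """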
def __init__(self, config):
super().__init__()
self.save_hyperparameters()
self.config = config
if config["general"]["use_gst"]:
self.encoder = EncoderModule(config)
self.gst = GSTModule(config)
else:
self.encoder = EncoderModule(config, use_channel=True)
self.channelfeats = ChannelFeatureModule(config)
self.channel = ChannelModule(config)
if config["train"]["load_pretrained"]:
pre_model = PretrainLightningModule.load_from_checkpoint(
checkpoint_path=config["train"]["pretrained_path"]
)
self.encoder.load_state_dict(pre_model.encoder.state_dict(), strict=False)
self.channel.load_state_dict(pre_model.channel.state_dict(), strict=False)
if config["general"]["use_gst"]:
self.gst.load_state_dict(pre_model.gst.state_dict(), strict=False)
else:
self.channelfeats.load_state_dict(
pre_model.channelfeats.state_dict(), strict=False
)
self.vocoder = None
self.criteria = self.get_loss_function(config)
def training_step(self, batch, batch_idx):
raise NotImplementedError()
def validation_step(self, batch, batch_idx):
raise NotImplementedError()
def validation_epoch_end(self, outputs):
raise NotImplementedError()
def configure_optimizers(self):
raise NotImplementedError()
def get_loss_function(self, config):
raise NotImplementedError()
def forward(self, melspecs, f0s=None):
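        """Restore audio: encode the degraded mels, vocode the predicted clean
        features (prepending f0s when feature_type == "vocfeats"), then
        re-degrade the result through the estimated channel. Returns
        (remastered_wav, pseudo_degraded_wav).
        """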
if self.config["general"]["use_gst"]:
enc_out = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
chfeats = self.gst(melspecs.transpose(1, 2))
else:
enc_out, enc_hidden = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
chfeats = self.channelfeats(enc_hidden)
enc_out = enc_out.squeeze(1).transpose(1, 2)
if self.config["general"]["feature_type"] == "melspec":
enc_feats = enc_out
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats = torch.cat((f0s.unsqueeze(1), enc_out), dim=1)
remas = self.vocoder(enc_feats).squeeze(1)
wavsdeg = self.channel(remas, chfeats)
return remas, wavsdeg
def test_step(self, batch, batch_idx):
if self.config["general"]["use_gst"]:
enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
chfeats = self.gst(batch["melspecs"].transpose(1, 2))
else:
enc_out, enc_hidden = self.encoder(
batch["melspecs"].unsqueeze(1).transpose(2, 3)
)
chfeats = self.channelfeats(enc_hidden)
enc_out = enc_out.squeeze(1).transpose(1, 2)
if self.config["general"]["feature_type"] == "melspec":
enc_feats = enc_out
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
remas = self.vocoder(enc_feats).squeeze(1)
wavsdeg = self.channel(remas, chfeats)
if self.config["general"]["feature_type"] == "melspec":
enc_feats_input = batch["melspecs"]
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats_input = torch.cat(
(batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1
)
input_recons = self.vocoder(enc_feats_input).squeeze(1)
if "wavsaux" in batch:
gt_wav = batch["wavsaux"]
if self.config["general"]["feature_type"] == "melspec":
enc_feats_aux = batch["melspecsaux"]
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats_aux = torch.cat(
(batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1
)
recons_wav = self.vocoder(enc_feats_aux).squeeze(1)
else:
gt_wav = None
recons_wav = None
return {
"reconstructed": recons_wav,
"remastered": remas,
"channeled": wavsdeg,
"input": batch["wavs"],
"input_recons": input_recons,
"groundtruth": gt_wav,
}
def test_epoch_end(self, outputs):
wav_dir = (
pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs"
)
os.makedirs(wav_dir, exist_ok=True)
        mel_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels"
        )
        # mel_dir is prepared for optional mel-spectrogram plots; the plotting
        # utility itself is not included in this file.
        os.makedirs(mel_dir, exist_ok=True)
        print("Saving output waveforms ...")
for idx, out in enumerate(tqdm.tqdm(outputs)):
            plot_keys = []  # keys with audio present; collected for optional mel plotting (unused here)
for key in [
"reconstructed",
"remastered",
"channeled",
"input",
"input_recons",
"groundtruth",
]:
                if out[key] is not None:
                    plot_keys.append(key)
                    torchaudio.save(
                        str(wav_dir / "{}-{}.wav".format(idx, key)),
                        out[key][0, ...].unsqueeze(0).cpu(),
                        sample_rate=self.config["preprocess"]["sampling_rate"],
                        channels_first=True,
                    )
class SSLStepLightningModule(SSLBaseModule):
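    """Self-supervised training with step-wise (alternating) optimization.

    Keeps two optimizers: one for the channel side (GST or channelfeats,
    plus the channel itself unless fix_channel is set) and one for the
    encoder. optimizer_step below switches between them by epoch.
    """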
def __init__(self, config):
super().__init__(config)
if config["train"]["fix_channel"]:
for param in self.channel.parameters():
param.requires_grad = False
def training_step(self, batch, batch_idx, optimizer_idx):
if self.config["general"]["use_gst"]:
enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
chfeats = self.gst(batch["melspecs"].transpose(1, 2))
else:
enc_out, enc_hidden = self.encoder(
batch["melspecs"].unsqueeze(1).transpose(2, 3)
)
chfeats = self.channelfeats(enc_hidden)
enc_out = enc_out.squeeze(1).transpose(1, 2)
if self.config["general"]["feature_type"] == "melspec":
enc_feats = enc_out
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
remas = self.vocoder(enc_feats).squeeze(1)
wavsdeg = self.channel(remas, chfeats)
loss = self.criteria(wavsdeg, batch["wavs"])
self.log(
"train_loss",
loss,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss
def validation_step(self, batch, batch_idx):
if self.config["general"]["use_gst"]:
enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
chfeats = self.gst(batch["melspecs"].transpose(1, 2))
else:
enc_out, enc_hidden = self.encoder(
batch["melspecs"].unsqueeze(1).transpose(2, 3)
)
chfeats = self.channelfeats(enc_hidden)
enc_out = enc_out.squeeze(1).transpose(1, 2)
if self.config["general"]["feature_type"] == "melspec":
enc_feats = enc_out
feats_name = "melspec"
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
feats_name = "melcep"
remas = self.vocoder(enc_feats).squeeze(1)
wavsdeg = self.channel(remas, chfeats)
loss = self.criteria(wavsdeg, batch["wavs"])
logger_img_dict = {
"val_src_melspec": batch["melspecs"],
"val_pred_{}".format(feats_name): enc_out,
}
for auxfeats in ["melceps", "melspecsaux"]:
if auxfeats in batch:
logger_img_dict["val_aux_{}".format(auxfeats)] = batch[auxfeats]
logger_wav_dict = {
"val_src_wav": batch["wavs"],
"val_remastered_wav": remas,
"val_pred_wav": wavsdeg,
}
if "wavsaux" in batch:
logger_wav_dict["val_aux_wav"] = batch["wavsaux"]
d_out = {"val_loss": loss, "logger_dict": [logger_img_dict, logger_wav_dict]}
return d_out
def validation_epoch_end(self, outputs):
self.log(
"val_loss",
torch.stack([out["val_loss"] for out in outputs]).mean().item(),
on_epoch=True,
prog_bar=True,
logger=True,
)
def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_idx,
optimizer_closure,
on_tpu=False,
using_native_amp=False,
using_lbfgs=False,
):
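        # Alternating schedule: for the first epoch_channel epochs only the
        # channel-side optimizer (idx 0) steps; afterwards only the encoder
        # optimizer (idx 1) steps. The idle optimizer still runs its closure
        # so Lightning's forward/backward bookkeeping stays consistent.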
if epoch < self.config["train"]["epoch_channel"]:
if optimizer_idx == 0:
optimizer.step(closure=optimizer_closure)
elif optimizer_idx == 1:
optimizer_closure()
else:
if optimizer_idx == 0:
optimizer_closure()
elif optimizer_idx == 1:
optimizer.step(closure=optimizer_closure)
def configure_optimizers(self):
if self.config["train"]["fix_channel"]:
if self.config["general"]["use_gst"]:
optimizer_channel = torch.optim.Adam(
self.gst.parameters(), lr=self.config["train"]["learning_rate"]
)
else:
optimizer_channel = torch.optim.Adam(
self.channelfeats.parameters(),
lr=self.config["train"]["learning_rate"],
)
optimizer_encoder = torch.optim.Adam(
self.encoder.parameters(), lr=self.config["train"]["learning_rate"]
)
else:
if self.config["general"]["use_gst"]:
optimizer_channel = torch.optim.Adam(
[
{"params": self.channel.parameters()},
{"params": self.gst.parameters()},
],
lr=self.config["train"]["learning_rate"],
)
else:
optimizer_channel = torch.optim.Adam(
[
{"params": self.channel.parameters()},
{"params": self.channelfeats.parameters()},
],
lr=self.config["train"]["learning_rate"],
)
optimizer_encoder = torch.optim.Adam(
self.encoder.parameters(), lr=self.config["train"]["learning_rate"]
)
optimizers = [optimizer_channel, optimizer_encoder]
schedulers = [
{
"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizers[0], mode="min", factor=0.5, min_lr=1e-5, verbose=True
),
"interval": "epoch",
"frequency": 3,
"monitor": "val_loss",
},
{
"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizers[1], mode="min", factor=0.5, min_lr=1e-5, verbose=True
),
"interval": "epoch",
"frequency": 3,
"monitor": "val_loss",
},
]
return optimizers, schedulers
def get_loss_function(self, config):
return MultiScaleSpectralLoss(config)
class SSLDualLightningModule(SSLBaseModule):
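    """Self-supervised training with a dual (cycle-consistency) objective.

    Besides reconstructing the observed degraded speech, clean "task" audio
    is passed through the estimated channel and the encoder must recover its
    clean features from the re-degraded signal; the two losses are mixed
    with weight beta.
    """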
def __init__(self, config):
super().__init__(config)
if config["train"]["fix_channel"]:
for param in self.channel.parameters():
param.requires_grad = False
self.spec_module = torchaudio.transforms.MelSpectrogram(
sample_rate=config["preprocess"]["sampling_rate"],
n_fft=config["preprocess"]["fft_length"],
win_length=config["preprocess"]["frame_length"],
hop_length=config["preprocess"]["frame_shift"],
f_min=config["preprocess"]["fmin"],
f_max=config["preprocess"]["fmax"],
n_mels=config["preprocess"]["n_mels"],
power=1,
center=True,
norm="slaney",
mel_scale="slaney",
)
self.beta = config["train"]["beta"]
self.criteria_a, self.criteria_b = self.get_loss_function(config)
def training_step(self, batch, batch_idx):
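        # Loss = beta * reconstruction + (1 - beta) * task feature loss.
        # The channel pass on the task audio runs under torch.no_grad, so the
        # task branch only updates the encoder.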
if self.config["general"]["use_gst"]:
enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
chfeats = self.gst(batch["melspecs"].transpose(1, 2))
else:
enc_out, enc_hidden = self.encoder(
batch["melspecs"].unsqueeze(1).transpose(2, 3)
)
chfeats = self.channelfeats(enc_hidden)
enc_out = enc_out.squeeze(1).transpose(1, 2)
if self.config["general"]["feature_type"] == "melspec":
enc_feats = enc_out
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
remas = self.vocoder(enc_feats).squeeze(1)
wavsdeg = self.channel(remas, chfeats)
loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
with torch.no_grad():
wavsdegtask = self.channel(batch["wavstask"], chfeats)
melspecstask = self.calc_spectrogram(wavsdegtask)
if self.config["general"]["use_gst"]:
enc_out_task = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
else:
enc_out_task, _ = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
enc_out_task = enc_out_task.squeeze(1).transpose(1, 2)
if self.config["general"]["feature_type"] == "melspec":
loss_task = self.criteria_b(enc_out_task, batch["melspecstask"])
elif self.config["general"]["feature_type"] == "vocfeats":
loss_task = self.criteria_b(enc_out_task, batch["melcepstask"])
loss = self.beta * loss_recons + (1 - self.beta) * loss_task
self.log(
"train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
)
self.log(
"train_loss_recons",
loss_recons,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
self.log(
"train_loss_task",
loss_task,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss
def validation_step(self, batch, batch_idx):
if self.config["general"]["use_gst"]:
enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
chfeats = self.gst(batch["melspecs"].transpose(1, 2))
else:
enc_out, enc_hidden = self.encoder(
batch["melspecs"].unsqueeze(1).transpose(2, 3)
)
chfeats = self.channelfeats(enc_hidden)
enc_out = enc_out.squeeze(1).transpose(1, 2)
if self.config["general"]["feature_type"] == "melspec":
enc_feats = enc_out
feats_name = "melspec"
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
feats_name = "melcep"
remas = self.vocoder(enc_feats).squeeze(1)
wavsdeg = self.channel(remas, chfeats)
loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
wavsdegtask = self.channel(batch["wavstask"], chfeats)
melspecstask = self.calc_spectrogram(wavsdegtask)
if self.config["general"]["use_gst"]:
enc_out_task = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
else:
enc_out_task, _ = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
enc_out_task = enc_out_task.squeeze(1).transpose(1, 2)
if self.config["general"]["feature_type"] == "melspec":
enc_out_task_truth = batch["melspecstask"]
loss_task = self.criteria_b(enc_out_task, enc_out_task_truth)
elif self.config["general"]["feature_type"] == "vocfeats":
enc_out_task_truth = batch["melcepstask"]
loss_task = self.criteria_b(enc_out_task, enc_out_task_truth)
loss = self.beta * loss_recons + (1 - self.beta) * loss_task
logger_img_dict = {
"val_src_melspec": batch["melspecs"],
"val_pred_{}".format(feats_name): enc_out,
"val_truth_{}_task".format(feats_name): enc_out_task_truth,
"val_pred_{}_task".format(feats_name): enc_out_task,
}
for auxfeats in ["melceps", "melspecsaux"]:
if auxfeats in batch:
logger_img_dict["val_aux_{}".format(auxfeats)] = batch[auxfeats]
logger_wav_dict = {
"val_src_wav": batch["wavs"],
"val_remastered_wav": remas,
"val_pred_wav": wavsdeg,
"val_truth_wavtask": batch["wavstask"],
"val_deg_wavtask": wavsdegtask,
}
if "wavsaux" in batch:
logger_wav_dict["val_aux_wav"] = batch["wavsaux"]
d_out = {
"val_loss": loss,
"val_loss_recons": loss_recons,
"val_loss_task": loss_task,
"logger_dict": [logger_img_dict, logger_wav_dict],
}
return d_out
def validation_epoch_end(self, outputs):
self.log(
"val_loss",
torch.stack([out["val_loss"] for out in outputs]).mean().item(),
on_epoch=True,
prog_bar=True,
logger=True,
)
self.log(
"val_loss_recons",
torch.stack([out["val_loss_recons"] for out in outputs]).mean().item(),
on_epoch=True,
prog_bar=True,
logger=True,
)
self.log(
"val_loss_task",
torch.stack([out["val_loss_task"] for out in outputs]).mean().item(),
on_epoch=True,
prog_bar=True,
logger=True,
)
def test_step(self, batch, batch_idx):
if self.config["general"]["use_gst"]:
enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
chfeats = self.gst(batch["melspecs"].transpose(1, 2))
else:
enc_out, enc_hidden = self.encoder(
batch["melspecs"].unsqueeze(1).transpose(2, 3)
)
chfeats = self.channelfeats(enc_hidden)
enc_out = enc_out.squeeze(1).transpose(1, 2)
if self.config["general"]["feature_type"] == "melspec":
enc_feats = enc_out
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
remas = self.vocoder(enc_feats).squeeze(1)
wavsdeg = self.channel(remas, chfeats)
if self.config["general"]["feature_type"] == "melspec":
enc_feats_input = batch["melspecs"]
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats_input = torch.cat(
(batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1
)
input_recons = self.vocoder(enc_feats_input).squeeze(1)
wavsdegtask = self.channel(batch["wavstask"], chfeats)
if "wavsaux" in batch:
gt_wav = batch["wavsaux"]
if self.config["general"]["feature_type"] == "melspec":
enc_feats_aux = batch["melspecsaux"]
elif self.config["general"]["feature_type"] == "vocfeats":
enc_feats_aux = torch.cat(
(batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1
)
recons_wav = self.vocoder(enc_feats_aux).squeeze(1)
else:
gt_wav = None
recons_wav = None
return {
"reconstructed": recons_wav,
"remastered": remas,
"channeled": wavsdeg,
"channeled_task": wavsdegtask,
"input": batch["wavs"],
"input_recons": input_recons,
"groundtruth": gt_wav,
}
def test_epoch_end(self, outputs):
wav_dir = (
pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs"
)
os.makedirs(wav_dir, exist_ok=True)
        mel_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels"
        )
        # mel_dir is prepared for optional mel-spectrogram plots; the plotting
        # utility itself is not included in this file.
        os.makedirs(mel_dir, exist_ok=True)
        print("Saving output waveforms ...")
for idx, out in enumerate(tqdm.tqdm(outputs)):
            plot_keys = []  # keys with audio present; collected for optional mel plotting (unused here)
for key in [
"reconstructed",
"remastered",
"channeled",
"channeled_task",
"input",
"input_recons",
"groundtruth",
]:
                if out[key] is not None:
                    plot_keys.append(key)
                    torchaudio.save(
                        str(wav_dir / "{}-{}.wav".format(idx, key)),
                        out[key][0, ...].unsqueeze(0).cpu(),
                        sample_rate=self.config["preprocess"]["sampling_rate"],
                        channels_first=True,
                    )
def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.parameters(), lr=self.config["train"]["learning_rate"]
)
lr_scheduler_config = {
"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, min_lr=1e-5, verbose=True
),
"interval": "epoch",
"frequency": 3,
"monitor": "val_loss",
}
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
def calc_spectrogram(self, wav):
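        # Log-compressed mel spectrogram matching the preprocessing pipeline:
        # magnitudes are clamped to min_magnitude, scaled by comp_factor,
        # then log-transformed.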
specs = self.spec_module(wav)
log_spec = torch.log(
torch.clamp_min(specs, self.config["preprocess"]["min_magnitude"])
* self.config["preprocess"]["comp_factor"]
).to(torch.float32)
return log_spec
def get_loss_function(self, config):
if config["train"]["feature_loss"]["type"] == "mae":
feature_loss = torch.nn.L1Loss()
else:
feature_loss = torch.nn.MSELoss()
return MultiScaleSpectralLoss(config), feature_loss
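
# Minimal usage sketch (hypothetical names: `my_config`, `my_vocoder`, and
# `my_datamodule` are placeholders, not part of this file):
#
#     module = SSLDualLightningModule(my_config)
#     module.vocoder = my_vocoder  # assumed: (B, n_feats, T) -> (B, 1, samples)
#     trainer = pl.Trainer(max_epochs=my_config["train"]["epoch"])
#     trainer.fit(module, datamodule=my_datamodule)
#     trainer.test(module, datamodule=my_datamodule)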