# aet_demo/aet.py
# (File-listing residue from the hosting page preserved as a comment:
#  author "saeki", commit 7b918f7, "fix", ~13.8 kB.)
import argparse
import pathlib
import yaml
import torch
import torchaudio
from torch.utils.data import DataLoader
import numpy as np
import random
import librosa
from dataset import Dataset
import pickle
from lightning_module import (
SSLStepLightningModule,
SSLDualLightningModule,
)
from utils import plot_and_save_mels
import os
import tqdm
class AETDataset(Dataset):
    """Paired source/target dataset for acoustic effect transfer (AET).

    The source corpus supplies degraded waveforms ("wavs") from which channel
    features are extracted; the target corpus supplies waveforms onto which
    the source channel is applied. Both corpora are truncated to a common
    length so index ``i`` refers to an aligned (src, tar) pair.

    NOTE(review): ``normalize_waveform`` and ``calc_spectrogram`` are
    inherited from ``Dataset`` (dataset.py) — confirm their contracts there.
    """

    def __init__(self, filetxt, src_config, tar_config):
        """Load pickled waveforms listed in *filetxt* for both corpora.

        Args:
            filetxt: file-list name (e.g. "test.txt") present under both
                preprocessed directories.
            src_config: parsed YAML config of the source corpus.
            tar_config: parsed YAML config of the target corpus.
        """
        self.config = src_config
        self.preprocessed_dir_src = pathlib.Path(
            src_config["general"]["preprocessed_path"]
        )
        self.preprocessed_dir_tar = pathlib.Path(
            tar_config["general"]["preprocessed_path"]
        )
        # The corpora must agree on every spectral parameter, otherwise mel
        # spectrograms computed below are not comparable across corpora.
        for item in [
            "sampling_rate",
            "fft_length",
            "frame_length",
            "frame_shift",
            "fmin",
            "fmax",
            "n_mels",
        ]:
            assert src_config["preprocess"][item] == tar_config["preprocess"][item]
        self.spec_module = torchaudio.transforms.MelSpectrogram(
            sample_rate=src_config["preprocess"]["sampling_rate"],
            n_fft=src_config["preprocess"]["fft_length"],
            win_length=src_config["preprocess"]["frame_length"],
            hop_length=src_config["preprocess"]["frame_shift"],
            f_min=src_config["preprocess"]["fmin"],
            f_max=src_config["preprocess"]["fmax"],
            n_mels=src_config["preprocess"]["n_mels"],
            power=1,
            center=True,
            norm="slaney",
            mel_scale="slaney",
        )
        with open(self.preprocessed_dir_src / filetxt, "r") as fr:
            self.filelist_src = [pathlib.Path(path.strip("\n")) for path in fr]
        with open(self.preprocessed_dir_tar / filetxt, "r") as fr:
            self.filelist_tar = [pathlib.Path(path.strip("\n")) for path in fr]
        self.d_out = {"src": {}, "tar": {}}
        for item in ["wavs", "wavsaux"]:
            self.d_out["src"][item] = []
            self.d_out["tar"][item] = []
        # The src and tar loading loops were duplicated; factored into one helper.
        self._load_corpus(
            "src", self.filelist_src, self.preprocessed_dir_src, src_config
        )
        self._load_corpus(
            "tar", self.filelist_tar, self.preprocessed_dir_tar, tar_config
        )
        # Truncate both corpora to the shorter one so indices stay aligned.
        min_len = min(len(self.d_out["src"]["wavs"]), len(self.d_out["tar"]["wavs"]))
        for spk in ["src", "tar"]:
            for item in ["wavs", "wavsaux"]:
                if self.d_out[spk][item] is not None:
                    self.d_out[spk][item] = np.asarray(self.d_out[spk][item][:min_len])

    def _load_corpus(self, spk, filelist, preprocessed_dir, config):
        """Extend ``self.d_out[spk]`` with waveforms from per-utterance pickles."""
        for wp in filelist:
            if config["general"]["corpus_type"] == "single":
                basename = str(wp.stem)
            else:
                # Multi-speaker corpora prefix the basename with the speaker dir.
                basename = str(wp.parent.name) + "-" + str(wp.stem)
            with open(preprocessed_dir / "{}.pickle".format(basename), "rb") as fr:
                d_preprocessed = pickle.load(fr)
            for item in ["wavs", "wavsaux"]:
                # "wavsaux" may be missing (or empty) in a pickle; skip only
                # those cases. (Was a bare ``except:`` that hid real errors.)
                try:
                    self.d_out[spk][item].extend(d_preprocessed[item])
                except (KeyError, TypeError):
                    pass

    def __len__(self):
        return len(self.d_out["src"]["wavs"])

    def __getitem__(self, idx):
        """Return -3 dB-normalized waveforms for both corpora plus the source mel."""
        d_batch = {}
        for spk in ["src", "tar"]:
            for item in ["wavs", "wavsaux"]:
                if self.d_out[spk][item].size > 0:
                    d_batch["{}_{}".format(item, spk)] = torch.from_numpy(
                        self.d_out[spk][item][idx]
                    )
                    d_batch["{}_{}".format(item, spk)] = self.normalize_waveform(
                        d_batch["{}_{}".format(item, spk)], db=-3
                    )
        d_batch["melspecs_src"] = self.calc_spectrogram(d_batch["wavs_src"])
        return d_batch
class AETModule(torch.nn.Module):
    """Applies the channel characteristics of one dataset to another.

    src: Dataset from which we extract the channel features
    tar: Dataset to which the src channel features are added
    """

    def __init__(self, args, chmatch_config, src_config, tar_config):
        """Build the module from a trained SSL checkpoint.

        Args:
            args: parsed CLI args; ``args.stage`` selects the Lightning module.
            chmatch_config: channel-matching config with the source ckpt path.
            src_config: source corpus config (also passed to the checkpoint).
            tar_config: target corpus config (kept for interface symmetry).
        """
        super().__init__()
        if args.stage == "ssl-step":
            LModule = SSLStepLightningModule
        elif args.stage == "ssl-dual":
            LModule = SSLDualLightningModule
        else:
            raise NotImplementedError()
        # ``load_from_checkpoint`` is a classmethod in PyTorch Lightning;
        # calling it on the class avoids building (and immediately discarding)
        # a throwaway model instance first.
        src_model = LModule.load_from_checkpoint(
            checkpoint_path=chmatch_config["general"]["source"]["ckpt_path"],
            config=src_config,
        )
        self.src_config = src_config
        self.encoder_src = src_model.encoder
        if src_config["general"]["use_gst"]:
            # GST path: channel features come from a global style token layer.
            self.gst_src = src_model.gst
        else:
            self.channelfeats_src = src_model.channelfeats
        self.channel_src = src_model.channel

    def forward(self, melspecs_src, wavsaux_tar):
        """Extract channel features from the source mel and apply them to
        the clean target waveform.

        Args:
            melspecs_src: source mel spectrogram batch — assumed
                (batch, n_mels, time); TODO confirm against the encoder.
            wavsaux_tar: clean target waveform batch.

        Returns:
            Target waveforms with source channel characteristics applied.
        """
        if self.src_config["general"]["use_gst"]:
            chfeats_src = self.gst_src(melspecs_src.transpose(1, 2))
        else:
            _, enc_hidden_src = self.encoder_src(
                melspecs_src.unsqueeze(1).transpose(2, 3)
            )
            chfeats_src = self.channelfeats_src(enc_hidden_src)
        wavschmatch_tar = self.channel_src(wavsaux_tar, chfeats_src)
        return wavschmatch_tar
def get_arg():
    """Parse command-line options.

    Required: --stage, --config_path, --run_name.
    Optional flag: --exist_src_aux (enables the degradation baseline).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--stage", required=True, type=str)
    ap.add_argument("--config_path", required=True, type=pathlib.Path)
    ap.add_argument("--exist_src_aux", action="store_true")
    ap.add_argument("--run_name", required=True, type=str)
    parsed = ap.parse_args()
    return parsed
def main(args, chmatch_config, device):
    """Run channel-matching inference over the test split and save results.

    For every test pair: saves the degraded source wav, the clean target wav,
    the channel-matched target wav, and (when ``--exist_src_aux`` is set) a
    spectral-division baseline — each as both audio and a mel-spectrogram plot.

    Args:
        args: parsed CLI args (stage, config_path, exist_src_aux, run_name).
        chmatch_config: channel-matching YAML config (paths to both corpora).
        device: torch device to run the AET module on.
    """
    src_config = yaml.load(
        open(chmatch_config["general"]["source"]["config_path"], "r"),
        Loader=yaml.FullLoader,
    )
    tar_config = yaml.load(
        open(chmatch_config["general"]["target"]["config_path"], "r"),
        Loader=yaml.FullLoader,
    )
    output_path = pathlib.Path(chmatch_config["general"]["output_path"]) / args.run_name
    dataset = AETDataset("test.txt", src_config, tar_config)
    # batch_size=1: each item is processed and saved individually below.
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    chmatch_module = AETModule(args, chmatch_config, src_config, tar_config).to(device)
    if args.exist_src_aux:
        # Precompute the average spectral characteristic once for the baseline.
        char_vector = calc_deg_charactaristics(chmatch_config)
    for idx, batch in enumerate(tqdm.tqdm(loader)):
        melspecs_src = batch["melspecs_src"].to(device)
        wavsdeg_src = batch["wavs_src"].to(device)
        wavsaux_tar = batch["wavsaux_tar"].to(device)
        if args.exist_src_aux:
            # Baseline: multiply the clean target spectrum by the precomputed
            # source degradation characteristic (no learned model involved).
            wavsdegbaseline_tar = calc_deg_baseline(
                batch["wavsaux_tar"], char_vector, tar_config
            )
            wavsdegbaseline_tar = normalize_waveform(wavsdegbaseline_tar, tar_config)
            wavsdeg_tar = batch["wavs_tar"].to(device)
        # Learned channel transfer, normalized to match the other outputs.
        wavsmatch_tar = normalize_waveform(
            chmatch_module(melspecs_src, wavsaux_tar).cpu().detach(), tar_config
        )
        torchaudio.save(
            output_path / "test_wavs" / "{}-src_wavsdeg.wav".format(idx),
            wavsdeg_src.cpu(),
            src_config["preprocess"]["sampling_rate"],
        )
        torchaudio.save(
            output_path / "test_wavs" / "{}-tar_wavsaux.wav".format(idx),
            wavsaux_tar.cpu(),
            tar_config["preprocess"]["sampling_rate"],
        )
        if args.exist_src_aux:
            torchaudio.save(
                output_path / "test_wavs" / "{}-tar_wavsdegbaseline.wav".format(idx),
                wavsdegbaseline_tar.cpu(),
                tar_config["preprocess"]["sampling_rate"],
            )
            torchaudio.save(
                output_path / "test_wavs" / "{}-tar_wavsdeg.wav".format(idx),
                wavsdeg_tar.cpu(),
                tar_config["preprocess"]["sampling_rate"],
            )
        torchaudio.save(
            output_path / "test_wavs" / "{}-tar_wavsmatch.wav".format(idx),
            wavsmatch_tar.cpu(),
            tar_config["preprocess"]["sampling_rate"],
        )
        # Mel-spectrogram plots for visual inspection of the same signals.
        plot_and_save_mels(
            wavsdeg_src[0, ...].cpu().detach(),
            output_path / "test_mels" / "{}-src_melsdeg.png".format(idx),
            src_config,
        )
        plot_and_save_mels(
            wavsaux_tar[0, ...].cpu().detach(),
            output_path / "test_mels" / "{}-tar_melsaux.png".format(idx),
            tar_config,
        )
        if args.exist_src_aux:
            plot_and_save_mels(
                wavsdegbaseline_tar[0, ...].cpu().detach(),
                output_path / "test_mels" / "{}-tar_melsdegbaseline.png".format(idx),
                tar_config,
            )
            plot_and_save_mels(
                wavsdeg_tar[0, ...].cpu().detach(),
                output_path / "test_mels" / "{}-tar_melsdeg.png".format(idx),
                tar_config,
            )
        plot_and_save_mels(
            wavsmatch_tar[0, ...].cpu().detach(),
            output_path / "test_mels" / "{}-tar_melsmatch.png".format(idx),
            tar_config,
        )
def calc_deg_baseline(wav, char_vector, tar_config):
    """Apply a precomputed spectral characteristic to a clean waveform.

    Pipeline: STFT -> per-frequency-bin multiplication by *char_vector*
    -> inverse STFT. Returns a float32 tensor of shape (1, T).
    """
    pp = tar_config["preprocess"]
    signal = wav[0, ...].cpu().detach().numpy()
    stft_mat = librosa.stft(
        signal,
        n_fft=pp["fft_length"],
        hop_length=pp["frame_shift"],
        win_length=pp["frame_length"],
    )
    # Broadcast the (n_bins,) characteristic across all time frames.
    shaped = stft_mat * char_vector.reshape(-1, 1)
    recon = librosa.istft(
        shaped,
        hop_length=pp["frame_shift"],
        win_length=pp["frame_length"],
    )
    return torch.from_numpy(recon).to(torch.float32).unsqueeze(0)
def calc_deg_charactaristics(chmatch_config):
    """Estimate the average spectral characteristic of the source channel.

    For each training utterance, computes the per-bin ratio between the
    degraded spectrum and its clean (aux) counterpart, averages over time
    and over utterances, and normalizes the vector to sum to 1.

    Args:
        chmatch_config: channel-matching YAML config with source/target paths.

    Returns:
        np.ndarray of shape (fft_length // 2 + 1,): normalized mean ratio.

    Raises:
        NotImplementedError: for an unknown ``corpus_type``.
    """
    src_config = yaml.load(
        open(chmatch_config["general"]["source"]["config_path"], "r"),
        Loader=yaml.FullLoader,
    )
    tar_config = yaml.load(
        open(chmatch_config["general"]["target"]["config_path"], "r"),
        Loader=yaml.FullLoader,
    )
    # configs
    preprocessed_dir = pathlib.Path(src_config["general"]["preprocessed_path"])
    n_train = src_config["preprocess"]["n_train"]
    SR = src_config["preprocess"]["sampling_rate"]
    os.makedirs(preprocessed_dir, exist_ok=True)
    sourcepath = pathlib.Path(src_config["general"]["source_path"])
    if src_config["general"]["corpus_type"] == "single":
        fulllist = list(sourcepath.glob("*.wav"))
        random.seed(0)
        random.shuffle(fulllist)
        train_filelist = fulllist[:n_train]
    elif src_config["general"]["corpus_type"] == "multi-seen":
        fulllist = list(sourcepath.glob("*/*.wav"))
        random.seed(0)
        random.shuffle(fulllist)
        train_filelist = fulllist[:n_train]
    elif src_config["general"]["corpus_type"] == "multi-unseen":
        # Sort before the seeded shuffle: set iteration order is not
        # deterministic across runs, which defeated random.seed(0).
        spk_list = sorted(set(x.parent for x in sourcepath.glob("*/*.wav")))
        train_filelist = []
        random.seed(0)
        random.shuffle(spk_list)
        for i, spk in enumerate(spk_list):
            sourcespkpath = sourcepath / spk
            if i < n_train:
                train_filelist.extend(list(sourcespkpath.glob("*.wav")))
    else:
        raise NotImplementedError(
            "corpus_type specified in config.yaml should be {single, multi-seen, multi-unseen}"
        )
    # Collect one mean-ratio column per utterance. (Previously an all-zeros
    # seed column was hstack-ed in and included in the final mean, biasing
    # the characteristic toward zero; repeated np.hstack was also O(n^2).)
    ratio_cols = []
    for wp in tqdm.tqdm(train_filelist):
        wav, _ = librosa.load(wp, sr=SR)
        spec = np.abs(
            librosa.stft(
                wav,
                n_fft=src_config["preprocess"]["fft_length"],
                hop_length=src_config["preprocess"]["frame_shift"],
                win_length=src_config["preprocess"]["frame_length"],
            )
        )
        auxpath = pathlib.Path(src_config["general"]["aux_path"])
        if src_config["general"]["corpus_type"] == "single":
            wav_aux, _ = librosa.load(auxpath / wp.name, sr=SR)
        else:
            wav_aux, _ = librosa.load(auxpath / wp.parent.name / wp.name, sr=SR)
        spec_aux = np.abs(
            librosa.stft(
                wav_aux,
                n_fft=src_config["preprocess"]["fft_length"],
                hop_length=src_config["preprocess"]["frame_shift"],
                win_length=src_config["preprocess"]["frame_length"],
            )
        )
        # Degraded and clean STFTs may differ by a frame; align before dividing.
        min_len = min(spec.shape[1], spec_aux.shape[1])
        spec_diff = spec[:, :min_len] / (spec_aux[:, :min_len] + 1e-10)
        ratio_cols.append(np.mean(spec_diff, axis=1).reshape(-1, 1))
    if ratio_cols:
        specs_all = np.hstack(ratio_cols)
    else:
        # No training files found: fall back to an all-zeros characteristic.
        specs_all = np.zeros((tar_config["preprocess"]["fft_length"] // 2 + 1, 1))
    char_vector = np.mean(specs_all, axis=1)
    char_vector = char_vector / (np.sum(char_vector) + 1e-10)
    return char_vector
def normalize_waveform(wav, tar_config, db=-3):
    """Peak-normalize *wav* to *db* dBFS using sox's ``norm`` effect."""
    effects = [["norm", "{}".format(db)]]
    normalized, _ = torchaudio.sox_effects.apply_effects_tensor(
        wav, tar_config["preprocess"]["sampling_rate"], effects
    )
    return normalized
if __name__ == "__main__":
    # Script entry point: parse CLI args and the channel-matching config,
    # create the output directory tree, then run inference (GPU if available).
    args = get_arg()
    chmatch_config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader)
    output_path = pathlib.Path(chmatch_config["general"]["output_path"]) / args.run_name
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(output_path / "test_wavs", exist_ok=True)
    os.makedirs(output_path / "test_mels", exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    main(args, chmatch_config, device)