Fckngproj / utils /tblogger.py
XDHDD's picture
Upload 8 files
e34c0af
raw
history blame
2.52 kB
from os import path
import librosa as rosa
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities import rank_zero_only
from utils.stft import STFTMag
# Select Matplotlib's non-interactive Agg backend: figures in this module are
# only rasterised into pixel buffers for TensorBoard, never shown in a window.
matplotlib.use('Agg')
class TensorBoardLoggerExpanded(TensorBoardLogger):
    """TensorBoard logger that additionally logs spectrogram images and audio.

    Three signals are logged together under the names ``y``, ``y_low`` and
    ``y_recon``. They are presumably the reference, the low-resolution input
    and the reconstruction; this is inferred from the names only, so confirm
    against the callers.
    """

    # Display names / tags for the three signals, in argument order.
    _SIGNAL_NAMES = ('y', 'y_low', 'y_recon')

    def __init__(self, sr=16000, save_dir='lightning_logs', name=''):
        """Create the logger.

        Args:
            sr: sample rate (Hz) passed to ``add_audio`` when logging clips.
            save_dir: root directory forwarded to ``TensorBoardLogger``.
            name: experiment name forwarded to ``TensorBoardLogger``.
        """
        super().__init__(save_dir=save_dir, default_hp_metric=False, name=name)
        self.sr = sr
        # Used to turn 1-D inputs into magnitude spectrograms before plotting.
        self.stftmag = STFTMag()

    def fig2np(self, fig):
        """Return the rendered pixels of ``fig`` as an (H, W, 3) uint8 RGB array.

        ``fig.canvas.draw()`` must have been called beforehand. The pixels are
        read from the Agg canvas' RGBA buffer. ``canvas.tostring_rgb`` is
        deprecated since Matplotlib 3.8 and removed in 3.10, and binary-mode
        ``np.fromstring`` is deprecated in NumPy.
        """
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop the alpha channel and copy so the result owns its memory
        # instead of aliasing the canvas buffer (freed when the figure closes).
        return rgba[..., :3].copy()

    def plot_spectrogram_to_numpy(self, y, y_low, y_recon, step):
        """Plot dB-scaled spectrograms of the three signals as one RGB image.

        Args:
            y, y_low, y_recon: CPU tensors. A 1-D tensor is first passed
                through ``self.stftmag``. Any other tensor is plotted as-is
                (presumably a 2-D (bins, frames) magnitude -- TODO confirm).
            step: value shown in the figure title as ``Epoch_{step}``.

        Returns:
            (H, W, 3) uint8 RGB array of the rendered figure.
        """
        fig = plt.figure(figsize=(9, 15))
        try:
            fig.suptitle(f'Epoch_{step}')
            axes = fig.subplots(3, 1)
            for ax, title, signal in zip(axes, self._SIGNAL_NAMES,
                                         (y, y_low, y_recon)):
                if signal.dim() == 1:
                    signal = self.stftmag(signal)
                ax.set_title(title)
                # ref=np.max maps the loudest bin to 0 dB (hence vmax=0);
                # top_db=80 floors everything at 80 dB below that peak.
                image = ax.imshow(rosa.amplitude_to_db(signal.numpy(),
                                                       ref=np.max, top_db=80.),
                                  vmax=0.,
                                  aspect='auto',
                                  origin='lower',
                                  interpolation='none')
                fig.colorbar(image, ax=ax)
                ax.set_xlabel('Frames')
                ax.set_ylabel('Channels')
            fig.tight_layout()
            fig.canvas.draw()
            return self.fig2np(fig)
        finally:
            # Always release the figure; pyplot otherwise keeps it alive,
            # leaking memory when plotting fails part-way.
            plt.close(fig)

    @rank_zero_only
    def log_spectrogram(self, y, y_low, y_recon, epoch):
        """Log a spectrogram comparison image at global step ``epoch``.

        The inputs are detached and moved to CPU before plotting.
        """
        y, y_low, y_recon = (t.detach().cpu() for t in (y, y_low, y_recon))
        spec_img = self.plot_spectrogram_to_numpy(y, y_low, y_recon, epoch)
        # NOTE(review): the tag embeds save_dir (e.g. 'lightning_logs/result');
        # kept unchanged so existing TensorBoard runs stay comparable.
        self.experiment.add_image(path.join(self.save_dir, 'result'),
                                  spec_img,
                                  epoch,
                                  dataformats='HWC')
        self.experiment.flush()

    @rank_zero_only
    def log_audio(self, y, y_low, y_recon, epoch):
        """Log the three signals as audio clips at ``self.sr`` Hz.

        Each clip is tagged with its name from ``_SIGNAL_NAMES``. The inputs
        are detached and moved to CPU first; presumably 1-D waveforms --
        TODO confirm against callers.
        """
        signals = tuple(t.detach().cpu() for t in (y, y_low, y_recon))
        for tag, clip in zip(self._SIGNAL_NAMES, signals):
            self.experiment.add_audio(tag, clip, epoch, self.sr)
        self.experiment.flush()