File size: 3,358 Bytes
d4c980e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
import numpy as np
import scipy.stats
from scipy.signal import butter, sosfilt
import torch
from pesq import pesq
from pystoi import stoi
def si_sdr_components(s_hat, s, n):
    """Split an estimate into target, noise and artifact components.

    The estimate ``s_hat`` is projected separately onto ``s`` and onto ``n``;
    whatever is left after subtracting both projections is the artifact term.

    Args:
        s_hat: Estimated signal. Presumably a 1-D array (``np.dot`` is
            expected to yield a scalar) -- TODO confirm against callers.
        s: Reference (target) signal, same length as ``s_hat``.
        n: Noise signal, same length as ``s_hat``.

    Returns:
        Tuple ``(s_target, e_noise, e_art)`` where ``s_target`` is ``s``
        scaled by its projection coefficient, ``e_noise`` is ``n`` scaled by
        its projection coefficient, and ``e_art`` is the residual
        ``s_hat - s_target - e_noise``.
    """
    # Projection coefficients of the estimate onto the reference and the noise.
    target_gain = np.dot(s_hat, s) / np.linalg.norm(s)**2
    noise_gain = np.dot(s_hat, n) / np.linalg.norm(n)**2

    s_target = target_gain * s
    e_noise = noise_gain * n

    # Everything not explained by the two projections counts as artifacts.
    e_art = s_hat - s_target - e_noise
    return s_target, e_noise, e_art
def energy_ratios(s_hat, s, n):
    """Compute SI-SDR, SI-SIR and SI-SAR (in dB) for an estimate.

    The estimate is decomposed with ``si_sdr_components`` and the energy of
    the target component is compared against the energy of
    (noise + artifacts), noise only, and artifacts only, respectively.

    Args:
        s_hat: Estimated signal.
        s: Reference (target) signal.
        n: Noise signal.

    Returns:
        Tuple ``(si_sdr, si_sir, si_sar)``, each ``10 * log10`` of the
        corresponding energy ratio.
    """
    s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)

    # The numerator is shared by all three ratios.
    target_energy = np.linalg.norm(s_target)**2

    def _to_db(distortion):
        # Ratio of target energy to the energy of the given distortion term.
        return 10*np.log10(target_energy / np.linalg.norm(distortion)**2)

    si_sdr = _to_db(e_noise + e_art)
    si_sir = _to_db(e_noise)
    si_sar = _to_db(e_art)
    return si_sdr, si_sir, si_sar
def mean_conf_int(data, confidence=0.95):
    """Return the sample mean and the half-width of its confidence interval.

    The half-width is the standard error of the mean multiplied by the
    Student-t quantile at ``(1 + confidence) / 2`` with ``len(data) - 1``
    degrees of freedom.

    Args:
        data: Sequence of numeric samples.
        confidence: Two-sided confidence level, 0.95 by default.

    Returns:
        Tuple ``(mean, half_width)``.
    """
    # Multiplying by 1.0 promotes integer input to floating point.
    samples = 1.0 * np.array(data)
    dof = len(samples) - 1

    mean = np.mean(samples)
    std_err = scipy.stats.sem(samples)

    t_quantile = scipy.stats.t.ppf((1 + confidence) / 2., dof)
    half_width = std_err * t_quantile
    return mean, half_width
class Method:
    """Container that accumulates per-metric values for one named method.

    Attributes are plain and public:
        name: Label given to the method.
        base_dir: Directory associated with the method (stored only; this
            class does not read from or write to it).
        metrics: Dict mapping each metric name to its own list of values.
    """

    def __init__(self, name, base_dir, metrics):
        """Create an empty value list for every metric name.

        Args:
            name: Label for the method.
            base_dir: Directory associated with the method.
            metrics: Iterable of hashable metric names.
        """
        self.name = name
        self.base_dir = base_dir
        # A comprehension gives every metric a distinct list (no sharing)
        # and works for any iterable, not only indexable sequences.
        self.metrics = {metric: [] for metric in metrics}

    def append(self, matric, value):
        """Record ``value`` under the metric named ``matric``.

        The parameter spelling is kept as-is so existing keyword callers
        continue to work.

        Raises:
            KeyError: If ``matric`` was not registered at construction.
        """
        self.metrics[matric].append(value)

    def get_mean_ci(self, metric):
        """Return ``mean_conf_int`` of the values stored for ``metric``."""
        return mean_conf_int(np.array(self.metrics[metric]))
def hp_filter(signal, cut_off=80, order=10, sr=16000):
    """Apply a Butterworth high-pass filter to ``signal``.

    Args:
        signal: Input samples passed to ``scipy.signal.sosfilt``.
        cut_off: Cut-off frequency, in the same unit as ``sr``.
        order: Order of the Butterworth filter.
        sr: Sampling rate of ``signal``.

    Returns:
        The filtered signal as returned by ``sosfilt``.
    """
    # butter() expects the critical frequency relative to Nyquist (sr / 2).
    normalized_cutoff = cut_off / sr * 2
    sections = butter(order, normalized_cutoff, btype='hp', output='sos')
    return sosfilt(sections, signal)
def si_sdr(s, s_hat):
    """Return the scale-invariant SDR (in dB) of ``s_hat`` w.r.t. ``s``.

    Args:
        s: Reference signal. Presumably a 1-D array -- TODO confirm.
        s_hat: Estimated signal, same length as ``s``.

    Returns:
        ``10 * log10`` of the energy of the optimally scaled reference over
        the energy of the remaining error.
    """
    # Scale that projects the estimate onto the reference.
    scale = np.dot(s_hat, s)/np.linalg.norm(s)**2
    scaled_ref = scale*s
    residual = scaled_ref - s_hat

    ratio = np.linalg.norm(scaled_ref)**2/np.linalg.norm(residual)**2
    return 10*np.log10(ratio)
def snr_dB(s, n):
    """Return the signal-to-noise ratio of ``s`` over ``n`` in decibels.

    Args:
        s: Signal samples (array supporting ``**`` and ``len``).
        n: Noise samples (array supporting ``**`` and ``len``).

    Returns:
        ``10 * log10`` of mean signal power over mean noise power.
    """
    signal_power = 1/len(s)*np.sum(s**2)
    noise_power = 1/len(n)*np.sum(n**2)
    # A distinct local name avoids shadowing this function inside its body.
    ratio_db = 10*np.log10(signal_power/noise_power)
    return ratio_db
def pad_spec(Y):
    """Zero-pad the last axis of ``Y`` up to the next multiple of 64.

    Args:
        Y: Tensor with at least four dimensions; dimension 3 is padded.
            Presumably (batch, channel, freq, time) -- TODO confirm.

    Returns:
        ``Y`` with zeros appended on the right of dimension 3 so that its
        size is divisible by 64 (unchanged in size if it already is).
    """
    num_frames = Y.size(3)
    # Distance to the next multiple of 64; zero when already aligned.
    num_pad = -num_frames % 64
    # ZeroPad2d takes (left, right, top, bottom); only the right side grows.
    padder = torch.nn.ZeroPad2d((0, num_pad, 0, 0))
    return padder(Y)
def ensure_dir(file_path):
    """Create the directory ``file_path`` (and parents) if it is missing.

    Despite the parameter name, the argument is used as the directory path
    itself, not as a file whose parent should be created.

    Args:
        file_path: Path of the directory to create.

    Returns:
        None.
    """
    # An already-existing path (directory or otherwise) is left untouched.
    if os.path.exists(file_path):
        return
    # exist_ok=True closes the race where another process creates the
    # directory between the existence check above and this call.
    os.makedirs(file_path, exist_ok=True)
def print_metrics(x, y, x_hat_list, labels, sr=16000):
    """Print PESQ, ESTOI and SI-SDR for a mixture and for each estimate.

    Every signal is scored against ``x``: first ``y`` (printed under the
    label ``Mixture``), then each entry of ``x_hat_list`` under the label at
    the same index in ``labels``.

    Args:
        x: Reference signal all metrics are computed against.
        y: Signal reported as the mixture.
        x_hat_list: Sequence of estimated signals.
        labels: Sequence of display names, indexed in step with
            ``x_hat_list``.
        sr: Sampling rate forwarded to ``pesq`` and ``stoi``.

    Raises:
        IndexError: If ``labels`` has fewer entries than ``x_hat_list``.
    """
    _si_sdr_mix = si_sdr(x, y)
    # 'wb' selects the wide-band mode of the pesq package.
    _pesq_mix = pesq(sr, x, y, 'wb')
    _estoi_mix = stoi(x, y, sr, extended=True)
    print(f'Mixture: PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
    for i, x_hat in enumerate(x_hat_list):
        _si_sdr = si_sdr(x, x_hat)
        _pesq = pesq(sr, x, x_hat, 'wb')
        _estoi = stoi(x, x_hat, sr, extended=True)
        # Same layout as the mixture line; the 'PESQ:' tag was missing here,
        # leaving the first number unlabeled.
        print(f'{labels[i]}: PESQ: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')
def mean_std(data):
    """Return the mean and standard deviation of ``data``, ignoring NaNs.

    Args:
        data: Array-like of numbers. Plain lists and tuples are accepted;
            they are converted to an array before NaN filtering.

    Returns:
        Tuple ``(mean, std)`` computed with ``np.mean`` and ``np.std``
        (population standard deviation) over the non-NaN entries.
    """
    # Boolean-mask indexing below only works on arrays, so coerce first.
    values = np.asarray(data)
    values = values[~np.isnan(values)]
    return np.mean(values), np.std(values)
def print_mean_std(data, decimal=2):
    """Format the NaN-ignoring mean and std of ``data`` as ``'mean ± std'``.

    Despite its name, this function returns the string and prints nothing.

    Args:
        data: Array-like of numbers; NaN entries are dropped.
        decimal: Number of digits after the decimal point for both values.
            Any non-negative integer is accepted (previously only 1 and 2
            produced a result; other values raised UnboundLocalError).

    Returns:
        The formatted string, e.g. ``'2.00 ± 1.00'``.
    """
    values = np.array(data)
    values = values[~np.isnan(values)]
    mean = np.mean(values)
    std = np.std(values)
    # Nested replacement field lets the precision follow ``decimal``.
    return f'{mean:.{decimal}f} ± {std:.{decimal}f}'
|