akhaliq3
spaces demo
5019931
raw
history blame
4.39 kB
'''
@File : subband_util.py
@Contact : liu.8948@buckeyemail.osu.edu
@License : (C)Copyright 2020-2021
@Modify Time @Author @Version @Description
------------ ------- -------- -----------
2020/4/3 4:54 PM Haohe Liu 1.0 None
'''
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import os.path as op
from scipy.io import loadmat
def load_mat2numpy(fname=""):
    """Read a MATLAB ``.mat`` file via ``scipy.io.loadmat``.

    Args:
        fname: Path to the ``.mat`` file. Must support ``len()``.

    Returns:
        Whatever ``loadmat`` returns for ``fname``, or ``None`` when
        ``fname`` has length zero (nothing is read in that case).
    """
    nothing_to_load = len(fname) == 0
    return None if nothing_to_load else loadmat(fname)
class PQMF(nn.Module):
    """Sub-band analysis/synthesis filter bank built from frozen ``Conv1d`` layers.

    The analysis filters are read from ``f_<N>_<M>.mat`` (key ``'f'``) and the
    synthesis filters from ``h_<N>_<M>.mat`` (key ``'h'``), both looked up in
    ``project_root``. All parameters have ``requires_grad`` set to ``False``.

    NOTE(review): the class name suggests a pseudo-QMF bank; the actual filter
    design lives in the ``.mat`` files and is not visible here.
    """

    def __init__(self, N, M, project_root):
        """Build the analysis and synthesis convolutions and load their weights.

        Args:
            N: Number of sub-bands; also the stride of the analysis convolution.
            M: Length of each analysis filter (analysis kernel size).
            project_root: Directory containing ``f_<N>_<M>.mat`` and
                ``h_<N>_<M>.mat``.

        Raises:
            Whatever ``scipy.io.loadmat`` raises when a filter file cannot be
            read, and ``RuntimeError`` from ``load_state_dict`` if the loaded
            coefficients do not match the convolution weight shapes.
        """
        super().__init__()
        self.N = N  # nsubband
        self.M = M  # nfilter
        # Explicit check instead of ``assert`` wrapped in a bare ``except``:
        # asserts are stripped under ``python -O`` (the warning would vanish)
        # and a bare except would also swallow unrelated exceptions.
        if (N, M) not in [(8, 64), (4, 64), (2, 64)]:
            print("Warning:", N, "subbandand ", M, " filter is not supported")
        # analysis() appends this many zeros to the input and synthesis()
        # removes the same number of samples from its output.
        self.pad_samples = 64
        self.name = str(N) + "_" + str(M) + ".mat"

        # ---- analysis: one input channel -> N sub-band channels, stride N ----
        self.ana_conv_filter = nn.Conv1d(
            1, out_channels=N, kernel_size=M, stride=N, bias=False
        )
        data = load_mat2numpy(op.join(project_root, "f_" + self.name))
        data = data['f'].astype(np.float32) / N
        # Reverse the second axis of the 2-D coefficient array.
        data = np.flipud(data.T).T
        # .copy() yields a contiguous array that torch.from_numpy accepts
        # (negative strides are not supported).
        data = np.reshape(data, (N, 1, M)).copy()
        # Left-pad by M - N zeros before the strided analysis convolution.
        self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
        # bias=False, so 'weight' is the layer's only state-dict entry.
        self.ana_conv_filter.load_state_dict({'weight': torch.from_numpy(data)})

        # ---- synthesis: N sub-band channels -> N interleavable channels ----
        # Right-pad by kernel_size - 1 so the stride-1 convolution keeps the
        # number of frames unchanged.
        self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
        self.syn_conv_filter = nn.Conv1d(
            N, out_channels=N, kernel_size=M // N, stride=1, bias=False
        )
        gk = load_mat2numpy(op.join(project_root, "h_" + self.name))
        gk = gk['h'].astype(np.float32)
        gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N
        # Reverse the first axis, then reorder to the Conv1d weight layout
        # (out_channels, in_channels, kernel_size) == (N, N, M // N).
        gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
        self.syn_conv_filter.load_state_dict({'weight': torch.from_numpy(gk)})

        # The filters are fixed; keep them out of gradient computation.
        for param in self.parameters():
            param.requires_grad = False

    def __analysis_channel(self, inputs):
        """Pad and filter a ``[batch, 1, time]`` tensor into N sub-bands."""
        return self.ana_conv_filter(self.ana_pad(inputs))

    def __systhesis_channel(self, inputs):
        """Filter one group of N sub-bands and interleave them into one channel.

        Returns a ``[batch, 1, frames * N]`` tensor: the N output channels are
        moved to the last axis and flattened together with the frame axis.
        """
        ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1)
        return torch.reshape(ret, (ret.shape[0], 1, -1))

    def analysis(self, inputs):
        """Split every channel of ``inputs`` into N sub-bands.

        :param inputs: tensor shaped [batchsize, channel, raw_wav]
        :return: tensor shaped [batchsize, channel * N, frames], with the N
            sub-bands of each input channel stored contiguously along dim 1;
            ``None`` if ``inputs`` has zero channels.
        """
        # Append pad_samples zeros at the end; synthesis() trims them again.
        inputs = F.pad(inputs, (0, self.pad_samples))
        # Collect per-channel results and concatenate once instead of
        # re-concatenating the growing result on every iteration.
        subbands = [
            self.__analysis_channel(inputs[:, ch : ch + 1, :])
            for ch in range(inputs.size()[1])
        ]
        if not subbands:
            return None
        return torch.cat(subbands, dim=1)

    def synthesis(self, data):
        """Rebuild waveform channels from groups of N sub-bands.

        :param data: tensor shaped [batchsize, N * K, raw_wav_sub]; each
            consecutive group of N channels forms one output channel.
        :return: tensor shaped [batchsize, K, samples] with the last
            ``pad_samples`` samples removed.
        :raises ValueError: if ``data`` has zero channels.
        """
        groups = [
            self.__systhesis_channel(data[:, start : start + self.N, :])
            for start in range(0, data.size()[1], self.N)
        ]
        if not groups:
            raise ValueError("synthesis() requires at least one sub-band channel")
        ret = torch.cat(groups, dim=1)
        # Drop the samples that correspond to the padding added in analysis().
        return ret[..., : -self.pad_samples]

    def forward(self, inputs):
        """Apply the analysis filter to a ``[batch, 1, time]`` tensor.

        Unlike :meth:`analysis`, no ``pad_samples`` zeros are appended and only
        a single input channel is supported.
        """
        return self.__analysis_channel(inputs)
if __name__ == "__main__":
    # Manual smoke test: split a random signal into sub-bands, rebuild it,
    # and inspect how far the reconstruction is from the input.
    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    from tools.file.wav import *

    pqmf = PQMF(N=4, M=64, project_root="/Users/admin/Documents/projects")

    # Fixed seed so the random test signal is the same on every run.
    rng = np.random.RandomState(0)
    signal = torch.tensor(rng.rand(4, 2, 32000), dtype=torch.float32)

    subbands = pqmf.analysis(signal)
    rebuilt = pqmf.synthesis(subbands)
    print(rebuilt.size(), signal.size())

    # Plot only the last 500 samples of the first batch item / first channel.
    tail = slice(-500, None)
    original_tail = signal[0, 0, tail]
    rebuilt_tail = rebuilt[0, 0, tail]

    # Top panel: the input signal.
    plt.subplot(211)
    plt.plot(original_tail)
    # Bottom panel: the reconstruction overlaid with the residual.
    plt.subplot(212)
    plt.plot(rebuilt_tail)
    plt.plot(original_tail - rebuilt_tail)
    plt.show()

    # Total absolute reconstruction error over the whole tensor.
    print(torch.sum(torch.abs(signal - rebuilt)))