ernestchu committed
Commit · 32bac05
1 Parent(s): 19b2e5e
init
Browse files
- .gitignore +7 -0
- app.py +117 -0
- packages.txt +1 -0
- requirements.txt +7 -0
- tsmnet/setup.py +7 -0
- tsmnet/tsmnet/__init__.py +1 -0
- tsmnet/tsmnet/dataset.py +79 -0
- tsmnet/tsmnet/interface.py +80 -0
- tsmnet/tsmnet/modules.py +180 -0
- tsmnet/tsmnet/utils.py +13 -0
- weights/args.yml +24 -0
- weights/classical-music.pt +3 -0
- weights/general.pt +3 -0
- weights/pop-music.pt +3 -0
- weights/speech.pt +3 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
flagged
__pycache__
.DS_Store
*.swp
*.egg-info
build
app.py
ADDED
@@ -0,0 +1,117 @@
import os
from tsmnet import Stretcher
import gradio as gr
from gradio import processing_utils
import torch
import torchaudio

model_root = './weights'
available_models = ['general', 'pop-music', 'classical-music', 'speech']
working_sr = 22050

def prepare_models():
    return {
        weight: Stretcher(os.path.join(model_root, f'{weight}.pt'))
        for weight in available_models
    }

def prepare_audio_file(rec, audio_file, yt_url):
    if rec is not None:
        return rec
    if audio_file is not None:
        return audio_file
    if yt_url != '':
        pass
    else:
        raise gr.Error('No audio found!')


def run(rec, audio_file, yt_url, speed, model, start_time, end_time):
    audio_file = prepare_audio_file(rec, audio_file, yt_url)
    if speed == 1:
        return processing_utils.audio_from_file(audio_file)

    model = models[model]

    x, sr = torchaudio.load(audio_file)
    x = torchaudio.transforms.Resample(orig_freq=sr, new_freq=working_sr)(x)
    sr = working_sr

    x = model(x, speed).cpu()

    torchaudio.save(audio_file, x, sr)

    return processing_utils.audio_from_file(audio_file)


# @@@@@@@ Start of the program @@@@@@@@

models = prepare_models()

with gr.Blocks() as demo:
    gr.Markdown('# TSM-Net')
    gr.Markdown('---')
    with gr.Row():
        with gr.Column():
            with gr.Tab('From microphone'):
                rec_box = gr.Audio(label='Recording', source='microphone', type='filepath')
            with gr.Tab('From file'):
                audio_file_box = gr.Audio(label='Audio sample', type='filepath')
            with gr.Tab('From YouTube'):
                yt_url_box = gr.Textbox(label='YouTube URL', placeholder='Under Construction', interactive=False)

            rec_box.change(lambda: [None] * 2, outputs=[audio_file_box, yt_url_box])
            audio_file_box.change(lambda: [None] * 2, outputs=[rec_box, yt_url_box])
            yt_url_box.input(lambda: [None] * 2, outputs=[rec_box, audio_file_box])

            speed_box = gr.Slider(label='Playback speed', minimum=0, maximum=2, value=1)
            with gr.Accordion('Fine-grained settings', open=False):
                with gr.Row():
                    gr.Textbox(label='', value='Trim audio sample', interactive=False)
                    start_time_box = gr.Number(label='Start', value=0)
                    end_time_box = gr.Number(label='End', value=20)

                model_box = gr.Dropdown(label='Model weight', choices=available_models, value=available_models[0])

            submit_btn = gr.Button('Submit')

        with gr.Column():
            with gr.Accordion('Hint', open=False):
                gr.Markdown('You can find more settings under the **Fine-grained settings**')
                gr.Markdown('- Feeling slow? Try to adjust the start/end timestamp')
                gr.Markdown('- Low audio quality? Try to switch to a proper model weight')
            outputs = gr.Audio(label='Output')

            submit_btn.click(fn=run, inputs=[
                rec_box,
                audio_file_box,
                yt_url_box,
                speed_box,
                model_box,
                start_time_box,
                end_time_box,
            ], outputs=outputs)

    with gr.Accordion('Read more ...', open=False):
        gr.Markdown('---')
        gr.Markdown(
            'We proposed a novel approach in the field of time-scale modification '
            'of audio signals. While traditional methods use the framing technique, '
            'the spectral approach uses the short-time Fourier transform to preserve '
            'the frequency during temporal stretching. TSM-Net, our neural-network '
            'model, encodes the raw audio into a high-level latent representation. '
            'We call it the Neuralgram, in which one vector represents 1024 audio samples. '
            'It is inspired by the framing technique but addresses the clipping '
            'artifacts. Since the Neuralgram is a two-dimensional matrix with real values, '
            'we can apply existing image-resizing techniques to the Neuralgram '
            'and decode it with our neural decoder to obtain the time-scaled audio. '
            'Both the encoder and decoder are trained with GANs, which show fair '
            'generalization ability on the scaled Neuralgrams. Our method yields '
            'few artifacts and opens a new possibility in the research of modern '
            'time-scale modification. Please find more details in our '
            '<a href="https://arxiv.org/abs/2210.17152" target="_blank">paper</a>.'
        )

demo.queue(4)
demo.launch(server_name='0.0.0.0')
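For reference, the pipeline that run() wires into the Gradio UI can also be exercised directly. A minimal sketch (not part of the commit), assuming weights/general.pt is present; input.wav and output.wav are hypothetical local files:

import torch
import torchaudio
from tsmnet import Stretcher

working_sr = 22050
stretcher = Stretcher('weights/general.pt')

x, sr = torchaudio.load('input.wav')                                        # (channels, samples)
x = torchaudio.transforms.Resample(orig_freq=sr, new_freq=working_sr)(x)    # model expects 22050 Hz

y = stretcher(x, 1.5).cpu()                                                 # rate > 1 means faster playback
torchaudio.save('output.wav', y, working_sr)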
packages.txt
ADDED
@@ -0,0 +1 @@
ffmpeg
requirements.txt
ADDED
@@ -0,0 +1,7 @@
./tsmnet
torch
torchvision
torchaudio
yt-dlp
wget
tsmnet/setup.py
ADDED
@@ -0,0 +1,7 @@
from setuptools import setup

setup(
    name='tsmnet',
    version='1.0.0',
    packages=['tsmnet'],
)
tsmnet/tsmnet/__init__.py
ADDED
@@ -0,0 +1 @@
from tsmnet.interface import load_model, Neuralgram, Stretcher
tsmnet/tsmnet/dataset.py
ADDED
@@ -0,0 +1,79 @@
import torch
import torch.utils.data
import torch.nn.functional as F
import torchaudio


from pathlib import Path
import numpy as np
import random


def files_to_list(filename):
    """
    Takes a text file of filenames and makes a list of filenames
    """
    with open(filename, encoding="utf-8") as f:
        files = f.readlines()

    files = [f.rstrip() for f in files]
    return files


class AudioDataset(torch.utils.data.Dataset):
    """
    This is the main class that calculates the spectrogram and returns the
    spectrogram, audio pair.
    """

    def __init__(self, training_files, segment_length, sampling_rate, augment=True):
        self.sampling_rate = sampling_rate
        self.segment_length = segment_length
        self.audio_files = files_to_list(training_files)
        self.audio_files = [Path(training_files).parent / x for x in self.audio_files]
        random.seed(1234)
        random.shuffle(self.audio_files)
        self.augment = augment

    def __getitem__(self, index):
        # Read audio
        filename = self.audio_files[index]
        try:
            audio, sampling_rate = self.load_wav_to_torch(filename)
        except RuntimeError:
            # there are lots of corrupted files in FMA
            print(f'Found corrupted file: {filename}, use empty data instead')
            audio = torch.tensor([])
        # Take segment
        if audio.size(0) >= self.segment_length:
            max_audio_start = audio.size(0) - self.segment_length
            audio_start = random.randint(0, max_audio_start)
            audio = audio[audio_start : audio_start + self.segment_length]
        else:
            audio = F.pad(
                audio, (0, self.segment_length - audio.size(0)), "constant"
            ).data

        # audio = audio / 32768.0
        return audio.unsqueeze(0)

    def __len__(self):
        return len(self.audio_files)

    def load_wav_to_torch(self, full_path):
        """
        Loads audio into torch array
        """
        data, sampling_rate = torchaudio.load(str(full_path))
        data = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=self.sampling_rate)(data)
        sampling_rate = self.sampling_rate

        if len(data.shape) > 1:
            # convert to mono
            data = data[random.randint(0, data.shape[0]-1)]

        if self.augment:
            amplitude = np.random.uniform(low=0.3, high=1.0)
            data = data * amplitude

        return data.float(), sampling_rate
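A hypothetical usage sketch for AudioDataset (not part of the commit). train_files.txt is an assumed text file listing one audio path per line, relative to its own directory; segment length and batch size mirror weights/args.yml:

import torch
from tsmnet.dataset import AudioDataset

dataset = AudioDataset('train_files.txt', segment_length=8192, sampling_rate=22050)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([2, 1, 8192]) -- mono segments, as returned by __getitem__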
tsmnet/tsmnet/interface.py
ADDED
@@ -0,0 +1,80 @@
from tsmnet.modules import Autoencoder

from torchvision.transforms.functional import resize
from torchvision.transforms import InterpolationMode
from pathlib import Path
import yaml
import torch
import os


def get_default_device():
    if torch.cuda.is_available():
        return "cuda"
    else:
        return "cpu"


def load_model(path, device=get_default_device()):
    """
    Args:
        path (str or Path): path to the model weight (.pt) file; the matching
            args.yml is expected in the same directory
        device (str or torch.device): device to load the model
    """
    root = Path(path)
    with open(os.path.join(os.path.dirname(path), "args.yml"), "r") as f:
        args = yaml.unsafe_load(f)
    netA = Autoencoder([int(n) for n in args.compress_ratios], args.ngf, args.n_residual_layers).to(device)
    netA.load_state_dict(torch.load(path, map_location=device))
    return netA


class Neuralgram:
    def __init__(
        self,
        path,
        device=None,
    ):
        if device is None:
            device = get_default_device()
        self.device = device
        self.netA = load_model(path, device)

    def __call__(self, audio):
        """
        Performs audio to neuralgram conversion (see Autoencoder.encoder in tsmnet/modules.py)
        Args:
            audio (torch.tensor): PyTorch tensor containing audio (batch_size, timesteps)
        Returns:
            torch.tensor: neuralgram computed on input audio (batch_size, channels, timesteps)
        """
        with torch.no_grad():
            return self.netA.encoder(torch.as_tensor(audio).unsqueeze(1).to(self.device))

    def inverse(self, neu):
        """
        Performs neuralgram to audio conversion
        Args:
            neu (torch.tensor): PyTorch tensor containing neuralgram (batch_size, channels, timesteps)
        Returns:
            torch.tensor: Inverted raw audio (batch_size, timesteps)
        """
        with torch.no_grad():
            return self.netA.decoder(neu.to(self.device)).squeeze(1)


class Stretcher:
    def __init__(self, path, device=None):
        self.neuralgram = Neuralgram(path, device)

    @torch.no_grad()
    def __call__(self, audio, rate, interpolation=InterpolationMode.NEAREST):  # NEAREST | BILINEAR | BICUBIC
        if rate == 1:
            return audio.numpy() if isinstance(audio, torch.Tensor) else audio
        neu = self.neuralgram(audio)
        neu_resized = resize(
            neu,
            (*neu.shape[1:-1], int(neu.shape[-1] * (1 / rate))),
            interpolation
        )
        return self.neuralgram.inverse(neu_resized)
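A minimal usage sketch of this interface (not part of the commit), assuming weights/general.pt is available; song.wav is a hypothetical 22050 Hz file:

import torchaudio
from tsmnet import Neuralgram, Stretcher

audio, sr = torchaudio.load('song.wav')          # (1, samples), expected at 22050 Hz

neuralgram = Neuralgram('weights/general.pt')
neu = neuralgram(audio)                          # (1, channels, ~samples / 1024) for the shipped weights
reconstructed = neuralgram.inverse(neu)          # (1, samples), plain encode/decode round trip

stretcher = Stretcher('weights/general.pt')
slowed = stretcher(audio, 0.5)                   # rate < 1 slows playback, i.e. more output samples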
tsmnet/tsmnet/modules.py
ADDED
@@ -0,0 +1,180 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.nn.utils import weight_norm
import numpy as np


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))

class ResnetBlock(nn.Module):
    def __init__(self, dim, dilation=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Tanh(),
            nn.ReflectionPad1d(dilation),
            WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
            nn.Tanh(),
            WNConv1d(dim, dim, kernel_size=1),
        )
        self.shortcut = WNConv1d(dim, dim, kernel_size=1)

    def forward(self, x):
        return self.shortcut(x) + self.block(x)

class Autoencoder(nn.Module):
    def __init__(self, compress_ratios, ngf, n_residual_layers):
        super().__init__()

        self.encoder = self.makeEncoder(compress_ratios, ngf, n_residual_layers)
        self.decoder = self.makeDecoder([r for r in reversed(compress_ratios)], ngf, n_residual_layers)

        self.apply(weights_init)

    def makeEncoder(self, ratios, ngf, n_residual_layers):
        mult = 1

        model = [
            nn.ReflectionPad1d(3),
            WNConv1d(1, ngf, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        # Downsample to neuralgram scale
        for i, r in enumerate(ratios):
            mult *= 2

            for j in range(n_residual_layers-1, -1, -1):
                model += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)]

            model += [
                nn.Tanh(),
                WNConv1d(
                    mult * ngf // 2,
                    mult * ngf,
                    kernel_size=r * 2,
                    stride=r,
                    padding=r // 2 + r % 2
                ),
            ]

        model += [ nn.Tanh() ]

        return nn.Sequential(*model)

    def makeDecoder(self, ratios, ngf, n_residual_layers):
        mult = int(2 ** len(ratios))

        model = []

        # Upsample to raw audio scale
        for i, r in enumerate(ratios):
            model += [
                nn.Tanh(),
                WNConvTranspose1d(
                    mult * ngf,
                    mult * ngf // 2,
                    kernel_size=r * 2,
                    stride=r,
                    padding=r // 2 + r % 2,
                    output_padding=r % 2
                ),
            ]

            for j in range(n_residual_layers):
                model += [ResnetBlock(mult * ngf // 2, dilation=3 ** j)]

            mult //= 2

        model += [
            nn.Tanh(),
            nn.ReflectionPad1d(3),
            WNConv1d(ngf, 1, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        return nn.Sequential(*model)

    def forward(self, x):
        return self.decoder(self.encoder(x))

class NLayerDiscriminator(nn.Module):
    def __init__(self, ndf, n_layers, downsampling_factor):
        super().__init__()
        model = nn.ModuleDict()

        model["layer_0"] = nn.Sequential(
            nn.ReflectionPad1d(7),
            WNConv1d(1, ndf, kernel_size=15),
            nn.Tanh(),
        )

        nf = ndf
        stride = downsampling_factor
        for n in range(1, n_layers + 1):
            nf_prev = nf
            nf = min(nf * stride, 1024)

            model["layer_%d" % n] = nn.Sequential(
                WNConv1d(
                    nf_prev,
                    nf,
                    kernel_size=stride * 10 + 1,
                    stride=stride,
                    padding=stride * 5,
                    groups=nf_prev // 4,
                ),
                nn.Tanh(),
            )

        nf = min(nf * 2, 1024)
        model["layer_%d" % (n_layers + 1)] = nn.Sequential(
            WNConv1d(nf_prev, nf, kernel_size=5, stride=1, padding=2),
            nn.Tanh(),
        )

        model["layer_%d" % (n_layers + 2)] = WNConv1d(
            nf, 1, kernel_size=3, stride=1, padding=1
        )

        self.model = model

    def forward(self, x):
        results = []
        for key, layer in self.model.items():
            x = layer(x)
            results.append(x)
        return results


class Discriminator(nn.Module):
    def __init__(self, num_D, ndf, n_layers, downsampling_factor):
        super().__init__()
        self.model = nn.ModuleDict()
        for i in range(num_D):
            self.model[f"disc_{i}"] = NLayerDiscriminator(
                ndf, n_layers, downsampling_factor
            )

        self.downsample = nn.AvgPool1d(4, stride=2, padding=1, count_include_pad=False)
        self.apply(weights_init)

    def forward(self, x):
        results = []
        for key, disc in self.model.items():
            results.append(disc(x))
            x = self.downsample(x)
        return results
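A hypothetical shape check for the Autoencoder (not part of the commit), using the hyperparameters recorded in weights/args.yml (compress_ratios '22488', ngf 32, n_residual_layers 1):

import torch
from tsmnet.modules import Autoencoder

net = Autoencoder([2, 2, 4, 8, 8], ngf=32, n_residual_layers=1)
x = torch.randn(1, 1, 8192)        # (batch, 1, samples), one seq_len-sized segment
neu = net.encoder(x)
print(neu.shape)                   # torch.Size([1, 1024, 8]): 2*2*4*8*8 = 1024x downsampling in time
print(net.decoder(neu).shape)      # torch.Size([1, 1, 8192]), back to the input length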
tsmnet/tsmnet/utils.py
ADDED
@@ -0,0 +1,13 @@
import scipy.io.wavfile


def save_sample(file_path, sampling_rate, audio):
    """Helper function to save sample

    Args:
        file_path (str or pathlib.Path): save file path
        sampling_rate (int): sampling rate of audio (usually 22050)
        audio (torch.FloatTensor): torch array containing audio in [-1, 1]
    """
    audio = (audio.numpy() * 32768).astype("int16")
    scipy.io.wavfile.write(file_path, sampling_rate, audio)
weights/args.yml
ADDED
@@ -0,0 +1,24 @@
!!python/object:argparse.Namespace
batch_size: 2
compress_ratios: '22488'
cond_disc: false
data_path: !!python/object/apply:pathlib.PosixPath
- /
- home
- b073040018
- Datasets
downsamp_factor: 4
epochs: 3000
lambda_feat: 10
load_path: logs-all/weights
log_interval: 100
n_layers_D: 4
n_residual_layers: 1
n_test_samples: 8
ndf: 16
ngf: 32
num_D: 3
project: tsmnet-all
save_interval: 1000
save_path: logs-all2
seq_len: 8192
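Note how load_model in tsmnet/tsmnet/interface.py consumes this file: compress_ratios is stored as the string '22488' and split digit by digit, which can be sanity-checked like so:

import math
ratios = [int(n) for n in '22488']   # as done in load_model -> [2, 2, 4, 8, 8]
print(math.prod(ratios))             # 1024, i.e. one Neuralgram vector per 1024 audio samples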
weights/classical-music.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c736e5c7414354ad2789b4d8dd6d3ab2d5813f52fb0982818a4fff8887d2eeba
size 100400811
weights/general.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e70b0ca672ab2008da3517ae3eb524135a1ef5685d59cc034084316a665f69f6
size 100400920
weights/pop-music.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f3010d34e0d538ecb4c63c8bc89ad4023630dc36e2746bb71b799026d2b03ad4
size 100400898
weights/speech.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e29674ce2312e1ba8f9071348de84031e8afbb08412cbc8088b7365f2162f497
size 100400879