Nithya committed on
Commit
98eb218
·
1 Parent(s): 3752793

updated parent repo and restructured things

.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
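
The deleted .gitattributes above had routed large binary artifacts (checkpoints, serialized transformers, archives) through Git LFS. As a rough illustration only (fnmatch approximates but does not exactly match git's pattern syntax, and the candidate paths are simply the checkpoint/transformer files referenced in app.py below), a few of the removed patterns can be checked against file paths like so:

import fnmatch

# A handful of the LFS patterns removed above, checked against paths referenced in app.py.
lfs_patterns = ["*.ckpt", "*.joblib", "*.safetensors", "*tfevents*"]
candidates = ["models/diffusion_pitch/last.ckpt",
              "models/pitch_to_audio/qt.joblib",
              "app.py"]

for path in candidates:
    via_lfs = any(fnmatch.fnmatch(path, pattern) for pattern in lfs_patterns)
    print(f"{path}: {'stored via LFS' if via_lfs else 'regular git object'}")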
 
.gitignore DELETED
@@ -1 +0,0 @@
1
- src/__pycache__/
 
 
app.py CHANGED
@@ -1,91 +1,28 @@
1
  import spaces
2
- from gradio import Interface, Audio
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
- import subprocess
7
  import librosa
8
  import matplotlib.pyplot as plt
9
  import pandas as pd
10
  import os
11
  from functools import partial
12
  import gin
13
- import sys
14
- sys.path.append('./')
15
- from src.generate_utils import invert_pitch_read, load_pitch_model, load_audio_model
16
- import src.pitch_to_audio_utils as p2a
17
  import torchaudio
18
  from absl import app
19
  from torch.nn.functional import interpolate
20
- import pdb
21
  import logging
22
  import crepe
23
  from hmmlearn import hmm
24
- import time
25
  import soundfile as sf
 
26
 
27
  pitch_path = 'models/diffusion_pitch/'
28
- # pitch_path = '/network/scratch/n/nithya.shikarpur/checkpoints/pitch-diffusion/corrected-attention-v3/4833583'
29
  audio_path = 'models/pitch_to_audio/'
30
- # audio_path = '/network/scratch/n/nithya.shikarpur/checkpoints/pitch-diffusion/corrected-attention-v3/4835364'
31
- # db_path_audio = '/home/mila/n/nithya.shikarpur/scratch/pitch-diffusion/data/merged_data-finalest/cached-audio-pitch-16k'
32
-
33
- device = 'cuda'
34
-
35
- global_ind = -1
36
- global_audios = np.array([0.0])
37
- global_pitches = np.array([0])
38
- singer = 3
39
- audio_components = []
40
- preprocessed_primes = []
41
- selected_prime = None
42
-
43
-
44
-
45
- def make_prime_npz(prime):
46
- np.savez('./temp/prime.npz', concatenated_array=[[prime]])
47
-
48
- def load_pitch_fns():
49
- pitch_model, pitch_qt, _, pitch_task_fn = load_pitch_model(
50
- os.path.join(pitch_path, 'config.gin'),
51
- os.path.join(pitch_path, 'last.ckpt'),
52
- os.path.join(pitch_path, 'qt.joblib'),
53
- device=device
54
- )
55
- invert_pitch_fn = partial(
56
- invert_pitch_read,
57
- min_norm_pitch=gin.query_parameter('dataset.pitch_read_w_downsample.min_norm_pitch'),
58
- time_downsample=gin.query_parameter('dataset.pitch_read_w_downsample.time_downsample'),
59
- pitch_downsample=gin.query_parameter('dataset.pitch_read_w_downsample.pitch_downsample'),
60
- qt_transform=pitch_qt,
61
- min_clip=gin.query_parameter('dataset.pitch_read_w_downsample.min_clip'),
62
- max_clip=gin.query_parameter('dataset.pitch_read_w_downsample.max_clip')
63
- )
64
- return pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn
65
-
66
- def interpolate_pitch(pitch, audio_seq_len):
67
- pitch = interpolate(pitch, size=audio_seq_len, mode='linear')
68
- # plt.plot(pitch[0].squeeze(0).detach().cpu().numpy())
69
- # plt.savefig(f"./temp/interpolated_pitch.png")
70
- # plt.close()
71
- return pitch
72
-
73
- def load_audio_fns():
74
- ckpt = os.path.join(audio_path, 'last.ckpt')
75
- config = os.path.join(audio_path, 'config.gin')
76
- qt = os.path.join(audio_path, 'qt.joblib')
77
- # qt = '/home/mila/n/nithya.shikarpur/scratch/pitch-diffusion/data/merged_data-finalest/cached-audio-pitch-16k/qt.joblib'
78
-
79
- audio_model, audio_qt = load_audio_model(config, ckpt, qt, device=device)
80
- audio_seq_len = gin.query_parameter('%AUDIO_SEQ_LEN')
81
-
82
- invert_audio_fn = partial(
83
- p2a.normalized_mels_to_audio,
84
- qt=audio_qt,
85
- n_iter=200
86
- )
87
-
88
- return audio_model, audio_qt, audio_seq_len, invert_audio_fn
89
 
90
  def predict_voicing(confidence):
91
  # https://github.com/marl/crepe/pull/26
@@ -136,73 +73,67 @@ def extract_pitch(audio, unvoice=True, sr=16000, frame_shift_ms=10, log=True):
136
 
137
  return time, f0, confidence
138
 
139
- def generate_pitch(pitch, pitch_model, invert_pitch_fn, num_samples, num_steps, outfolder=None, processed_primes=None):
140
- noisy_pitch = torch.Tensor(pitch[:, :, :1200]).to(pitch_model.device) + (torch.normal(mean=0.0, std=0.4*torch.ones(( 1200)))).to(pitch_model.device)
141
- noisy_pitch = torch.clamp(noisy_pitch, -5.19, 5.19)
 
 
142
  samples = pitch_model.sample_sdedit(noisy_pitch, num_samples, num_steps)
143
- inverted_pitches = [invert_pitch_fn(samples.detach().cpu().numpy()[0])[0]]
144
 
145
- if outfolder is not None:
146
- os.makedirs(outfolder, exist_ok=True)
147
- # pdb.set_trace()
148
- for i, pitch in enumerate(inverted_pitches):
149
- flattened_pitch = pitch.flatten()
150
- pd.DataFrame({'f0': flattened_pitch}).to_csv(f"{outfolder}/{i}.csv", index=False)
151
- plt.plot(np.where(flattened_pitch == 0, np.nan, flattened_pitch))
152
- plt.savefig(f"{outfolder}/{i}.png")
153
- plt.close()
154
  return samples, inverted_pitches
155
 
156
- def generate_audio(audio_model, f0s, invert_audio_fn, outfolder, singers=[3], num_steps=100):
 
157
  singer_tensor = torch.tensor(np.repeat(singers, repeats=f0s.shape[0])).to(audio_model.device)
158
  samples, _, singers = audio_model.sample_cfg(f0s.shape[0], f0=f0s, num_steps=num_steps, singer=singer_tensor, strength=3)
159
  audio = invert_audio_fn(samples)
160
-
161
- if outfolder is not None:
162
- os.makedirs(outfolder, exist_ok=True)
163
- for i, a in enumerate(audio):
164
- logging.log(logging.INFO, f"Saving audio {i}")
165
- torchaudio.save(f"{outfolder}/{i}.wav", torch.tensor(a).detach().unsqueeze(0).cpu(), 16000)
166
  return audio
167
 
168
  @spaces.GPU(duration=120)
169
- def generate(pitch, num_samples=2, num_steps=100, singers=[3], outfolder='temp', audio_seq_len=750, pitch_qt=None ):
170
- global global_ind, audio_components
171
- global preprocessed_primes
172
- # pdb.set_trace()
173
  logging.log(logging.INFO, 'Generate function')
174
- pitch, inverted_pitch = generate_pitch(pitch, pitch_model, invert_pitch_fn, 1, 100, outfolder=outfolder, processed_primes=selected_prime if global_ind != 0 else None)
175
  if pitch_qt is not None:
 
176
  def undo_qt(x, min_clip=200):
177
  pitch= pitch_qt.inverse_transform(x.reshape(-1, 1)).reshape(1, -1)
178
  pitch = np.around(pitch) # round to nearest integer, done in preprocessing of pitch contour fed into model
179
  pitch[pitch < 200] = np.nan
180
  return pitch
181
  pitch = torch.tensor(np.array([undo_qt(x) for x in pitch.detach().cpu().numpy()])).to(pitch_model.device)
182
- interpolated_pitch = interpolate_pitch(pitch=pitch, audio_seq_len=audio_seq_len)
183
- interpolated_pitch = torch.nan_to_num(interpolated_pitch, nan=196)
184
  interpolated_pitch = interpolated_pitch.squeeze(1) # to match input size by removing the extra dimension
185
- audio = generate_audio(audio_model, interpolated_pitch, invert_audio_fn, singers=singers, num_steps=100, outfolder=outfolder)
186
- # pdb.set_trace()
187
- audio = audio.detach().cpu().numpy()[:, :]
188
  pitch = pitch.detach().cpu().numpy()
189
- # state = [(16000, audio[0]), (16000, audio[1])]
190
- # pdb.set_trace()
191
  pitch_vals = np.where(pitch[0][:, 0] == 0, np.nan, pitch[0].flatten())
192
- fig1 = plt.figure()
193
- # plt.plot(np.arange(0, 400), pitch_vals[:400], figure=fig1, label='User Input')
194
- plt.plot(pitch_vals, figure=fig1, label='Pitch')
195
- # plt.legend(fig1)
196
- # state.append(fig1)
197
- plt.close(fig1)
198
- return (16000, audio[0]), fig1, pitch_vals
199
 
200
- pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn = load_pitch_fns()
201
- audio_model, audio_qt, audio_seq_len, invert_audio_fn = load_audio_fns()
202
- partial_generate = partial(generate, num_samples=1, num_steps=100, singers=[3], outfolder=None, pitch_qt=pitch_qt)
203
 
204
- @spaces.GPU(duration=180)
205
- def set_prime_and_generate(audio, full_pitch, full_audio, full_user):
206
  global selected_prime, pitch_task_fn
207
 
208
  if audio is None:
@@ -215,40 +146,32 @@ def set_prime_and_generate(audio, full_pitch, full_audio, full_user):
215
  audio /= np.max(np.abs(audio))
216
  audio = librosa.resample(audio, orig_sr=sr, target_sr=16000) # convert only last 4 s
217
  mic_audio = audio.copy()
218
- audio = audio[-12*16000:]
219
  _, f0, _ = extract_pitch(audio)
220
- mic_f0 = f0.copy()
221
- f0 = pitch_task_fn({
222
- 'pitch': {
223
- 'data': f0,
224
- 'sampling_rate': 100
225
- }
226
- }, qt_transform=pitch_qt)
227
  f0 = f0.reshape(1, 1, -1)
228
  f0 = torch.tensor(f0).to(pitch_model.device).float()
229
- audio, pitch, pitch_vals = partial_generate(f0)
230
- # pdb.set_trace()
231
- full_pitch = np.concatenate((full_pitch, mic_f0, pitch_vals))
232
- full_user = np.concatenate((full_user, ['User'] * len(mic_f0), ['Model'] * len(pitch_vals)))
233
- full_audio[1] = np.concatenate((full_audio[1], mic_audio, audio[1]))
234
- # pdb.set_trace()
235
- fig = plt.figure()
236
- plt.plot(np.arange(0, len(mic_f0)), mic_f0, label='User Input', figure=fig)
237
- plt.close(fig)
238
- return audio, full_pitch, full_audio, full_user, pitch
239
-
240
- def save_session(full_pitch, full_audio, full_user):
241
- pass
242
- # os.makedirs(output_folder, exist_ok=True)
243
- # filename = f'session-{time.time()}'
244
- # logging.log(logging.INFO, f"Saving session to {filename}")
245
- # pd.DataFrame({'pitch': full_pitch, 'time': np.arange(0, len(full_pitch)/100, 0.01), 'user': full_user}).to_csv(os.path.join(output_folder, filename + '.csv'), index=False)
246
- # sf.write(os.path.join(output_folder, filename + '.wav'), full_audio[1], 16000)
247
 
248
  with gr.Blocks() as demo:
249
- full_audio = gr.State((16000, np.array([])))
250
- full_pitch = gr.State(np.array([]))
251
- full_user = gr.State(np.array([]))
252
  with gr.Row():
253
  with gr.Column():
254
  audio = gr.Audio(label="Input")
@@ -257,17 +180,9 @@ with gr.Blocks() as demo:
257
  with gr.Column():
258
  generated_audio = gr.Audio(label="Generated Audio")
259
  generated_pitch = gr.Plot(label="Generated Pitch")
260
- sbmt.click(set_prime_and_generate, inputs=[audio, full_pitch, full_audio, full_user], outputs=[generated_audio, full_pitch, full_audio, full_user, user_input])
261
- save = gr.Button("Save Session")
262
- save.click(save_session, inputs=[full_pitch, full_audio, full_user])
263
-
264
-
265
 
266
  def main(argv):
267
- # audio = np.random.randint(0, high=128, size=(44100*5), dtype=np.int16)
268
- # sr = 44100
269
- # pdb.set_trace()
270
- # p, a = set_prime_and_generate((sr, audio))
271
 
272
  demo.launch(share=True)
273
 
 
1
  import spaces
 
2
  import gradio as gr
3
  import numpy as np
4
  import torch
 
5
  import librosa
6
  import matplotlib.pyplot as plt
7
  import pandas as pd
8
  import os
9
  from functools import partial
10
  import gin
11
+ from gamadhani.utils.generate_utils import load_pitch_fns, load_audio_fns
12
+ import gamadhani.utils.pitch_to_audio_utils as p2a
13
+ from gamadhani.utils.utils import get_device
 
14
  import torchaudio
15
  from absl import app
16
  from torch.nn.functional import interpolate
 
17
  import logging
18
  import crepe
19
  from hmmlearn import hmm
 
20
  import soundfile as sf
21
+ import pdb
22
 
23
  pitch_path = 'models/diffusion_pitch/'
 
24
  audio_path = 'models/pitch_to_audio/'
25
+ device = get_device()
26
 
27
  def predict_voicing(confidence):
28
  # https://github.com/marl/crepe/pull/26
 
73
 
74
  return time, f0, confidence
75
 
76
+ def generate_pitch_reinterp(pitch, pitch_model, invert_pitch_fn, num_samples, num_steps, noise_std=0.4):
77
+ '''Generate pitch values for the melodic reinterpretation task'''
78
+ # noise_std sets the amount of Gaussian noise added before partial re-denoising (default 0.4)
79
+ noisy_pitch = torch.Tensor(pitch[:, :, -1200:]).to(pitch_model.device) + (torch.normal(mean=0.0, std=noise_std*torch.ones((1200)))).to(pitch_model.device)
80
+ noisy_pitch = torch.clamp(noisy_pitch, -5.19, 5.19) # clipping the pitch values to be within the range of the model
81
  samples = pitch_model.sample_sdedit(noisy_pitch, num_samples, num_steps)
82
+ inverted_pitches = [invert_pitch_fn(f0=samples.detach().cpu().numpy()[0])[0]] # pitch values in Hz
83
84
  return samples, inverted_pitches
85
 
86
+ def generate_audio(audio_model, f0s, invert_audio_fn, singers=[3], num_steps=100):
87
+ '''Generate audio given pitch values'''
88
  singer_tensor = torch.tensor(np.repeat(singers, repeats=f0s.shape[0])).to(audio_model.device)
89
  samples, _, singers = audio_model.sample_cfg(f0s.shape[0], f0=f0s, num_steps=num_steps, singer=singer_tensor, strength=3)
90
  audio = invert_audio_fn(samples)
91
+
92
  return audio
93
 
94
  @spaces.GPU(duration=120)
95
+ def generate(pitch, num_samples=1, num_steps=100, singers=[3], outfolder='temp', audio_seq_len=750, pitch_qt=None ):
96
+
 
 
97
  logging.log(logging.INFO, 'Generate function')
98
+ pitch, inverted_pitch = generate_pitch_reinterp(pitch, pitch_model, invert_pitch_fn, num_samples=num_samples, num_steps=100)
99
  if pitch_qt is not None:
100
+ # if a pitch quantile transformer was used, invert it here to recover unnormalized pitch values
101
  def undo_qt(x, min_clip=200):
102
  pitch= pitch_qt.inverse_transform(x.reshape(-1, 1)).reshape(1, -1)
103
  pitch = np.around(pitch) # round to nearest integer, done in preprocessing of pitch contour fed into model
104
  pitch[pitch < 200] = np.nan
105
  return pitch
106
  pitch = torch.tensor(np.array([undo_qt(x) for x in pitch.detach().cpu().numpy()])).to(pitch_model.device)
107
+ interpolated_pitch = p2a.interpolate_pitch(pitch=pitch, audio_seq_len=audio_seq_len) # interpolate pitch values to match the audio model's input size
108
+ interpolated_pitch = torch.nan_to_num(interpolated_pitch, nan=196) # replace nan values with silent token
109
  interpolated_pitch = interpolated_pitch.squeeze(1) # to match input size by removing the extra dimension
110
+ audio = generate_audio(audio_model, interpolated_pitch, invert_audio_fn, singers=singers, num_steps=100)
111
+ audio = audio.detach().cpu().numpy()
 
112
  pitch = pitch.detach().cpu().numpy()
 
 
113
  pitch_vals = np.where(pitch[0][:, 0] == 0, np.nan, pitch[0].flatten())
114
 
115
+ # generate plot of model output to display on interface
116
+ model_output_plot = plt.figure()
117
+ plt.plot(pitch_vals, figure=model_output_plot, label='Model Output')
118
+ plt.close(model_output_plot)
119
+ return (16000, audio[0]), model_output_plot, pitch_vals
120
+
121
+ # pdb.set_trace()
122
+ pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn, _ = load_pitch_fns(
123
+ os.path.join(pitch_path, 'last.ckpt'), \
124
+ model_type = 'diffusion', \
125
+ config_path = os.path.join(pitch_path, 'config.gin'), \
126
+ qt_path = os.path.join(pitch_path, 'qt.joblib'), \
127
+ )
128
+ audio_model, audio_qt, audio_seq_len, invert_audio_fn = load_audio_fns(
129
+ os.path.join(audio_path, 'last.ckpt'),
130
+ qt_path = os.path.join(audio_path, 'qt.joblib'),
131
+ config_path = os.path.join(audio_path, 'config.gin')
132
+ )
133
+ partial_generate = partial(generate, num_samples=1, num_steps=100, singers=[3], outfolder=None, pitch_qt=pitch_qt) # generate function with default arguments
134
 
135
+ @spaces.GPU(duration=120)
136
+ def set_guide_and_generate(audio):
137
  global selected_prime, pitch_task_fn
138
 
139
  if audio is None:
 
146
  audio /= np.max(np.abs(audio))
147
  audio = librosa.resample(audio, orig_sr=sr, target_sr=16000) # resample to 16 kHz
148
  mic_audio = audio.copy()
149
+ audio = audio[-12*16000:] # consider only last 12 s
150
  _, f0, _ = extract_pitch(audio)
151
+ mic_f0 = f0.copy() # save the user input pitch values
152
+ f0 = pitch_task_fn(**{
153
+ 'inputs': {
154
+ 'pitch': {
155
+ 'data': torch.Tensor(f0), # task function expects a tensor
156
+ 'sampling_rate': 100
157
+ }
158
+ },
159
+ 'qt_transform': pitch_qt,
160
+ 'time_downsample': 1, # pitch will be extracted at 100 Hz, thus no downsampling
161
+ 'seq_len': None,
162
+ })['sampled_sequence']
163
+ # pdb.set_trace()
164
  f0 = f0.reshape(1, 1, -1)
165
  f0 = torch.tensor(f0).to(pitch_model.device).float()
166
+ audio, pitch, _ = partial_generate(f0)
167
+ mic_f0 = np.where(mic_f0 == 0, np.nan, mic_f0)
168
+ # plot user input
169
+ user_input_plot = plt.figure()
170
+ plt.plot(np.arange(0, len(mic_f0)), mic_f0, label='User Input', figure=user_input_plot)
171
+ plt.close(user_input_plot)
172
+ return audio, user_input_plot, pitch
173
 
174
  with gr.Blocks() as demo:
 
 
 
175
  with gr.Row():
176
  with gr.Column():
177
  audio = gr.Audio(label="Input")
 
180
  with gr.Column():
181
  generated_audio = gr.Audio(label="Generated Audio")
182
  generated_pitch = gr.Plot(label="Generated Pitch")
183
+ sbmt.click(set_guide_and_generate, inputs=[audio], outputs=[generated_audio, user_input, generated_pitch])
184
 
185
  def main(argv):
186
 
187
  demo.launch(share=True)
188
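
Taken together, the rewritten app.py wires up the following flow: the Gradio submit button calls set_guide_and_generate, which extracts f0 from the recorded audio with CREPE, normalizes it through pitch_task_fn, and hands it to generate; generate_pitch_reinterp then perturbs the contour and partially re-denoises it with pitch_model.sample_sdedit (the melodic-reinterpretation step), and generate_audio renders the result with the singer-conditioned audio model. A minimal, self-contained sketch of just the noising step is shown below; the input tensor is a random stand-in, the 0.4 / plus-minus 5.19 constants are the ones from the diff, and the sample_sdedit call is left as a comment because it needs the loaded checkpoint:

import torch

# Stand-in for a quantile-transformed pitch contour of shape (batch, 1, frames);
# in the app this comes from pitch_task_fn applied to the CREPE f0 track.
pitch = torch.randn(1, 1, 1200)

noise_std = 0.4  # default used by generate_pitch_reinterp
noisy_pitch = pitch[:, :, -1200:] + torch.normal(mean=0.0, std=noise_std * torch.ones(1200))
noisy_pitch = torch.clamp(noisy_pitch, -5.19, 5.19)  # keep values inside the model's training range

# With the checkpoint loaded via load_pitch_fns, partial denoising would be:
# samples = pitch_model.sample_sdedit(noisy_pitch, num_samples=1, num_steps=100)
print(noisy_pitch.shape, float(noisy_pitch.min()), float(noisy_pitch.max()))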
 
models/diffusion_pitch/config.gin CHANGED
@@ -1,7 +1,9 @@
1
  from __gin__ import dynamic_registration
2
- from src import dataset
3
- from src import model
4
- from src import utils
 
 
5
  import torch
6
 
7
  # Macros:
@@ -23,47 +25,46 @@ utils.build_warmed_exponential_lr_scheduler.eta_min = 0.1
23
  utils.build_warmed_exponential_lr_scheduler.peak_iteration = 10000
24
  utils.build_warmed_exponential_lr_scheduler.start_factor = 0.01
25
 
26
- # Parameters for model.UNetBase.configure_optimizers:
27
  # ==============================================================================
28
- model.UNetBase.configure_optimizers.optimizer_cls = @torch.optim.AdamW
29
- model.UNetBase.configure_optimizers.scheduler_cls = \
30
  @utils.build_warmed_exponential_lr_scheduler
31
 
32
- # Parameters for dataset.pitch_read_w_downsample:
33
  # ==============================================================================
34
- dataset.pitch_read_w_downsample.add_noise_to_silence = True
35
- dataset.pitch_read_w_downsample.decoder_key = 'pitch'
36
- dataset.pitch_read_w_downsample.max_clip = 600
37
- dataset.pitch_read_w_downsample.min_clip = 200
38
- dataset.pitch_read_w_downsample.min_norm_pitch = -4915
39
- dataset.pitch_read_w_downsample.pitch_downsample = 10
40
- dataset.pitch_read_w_downsample.seq_len = %SEQ_LEN
41
- dataset.pitch_read_w_downsample.time_downsample = 2
 
42
 
43
  # Parameters for train/dataset.pitch_read_w_downsample:
44
  # ==============================================================================
45
- train/dataset.pitch_read_w_downsample.transpose_pitch = %TRANSPOSE_VALUE
46
 
47
- # Parameters for train/dataset.SequenceDataset:
48
  # ==============================================================================
49
- train/dataset.SequenceDataset.task_fn = @train/dataset.pitch_read_w_downsample
 
50
 
51
- # Parameters for val/dataset.SequenceDataset:
52
- # ==============================================================================
53
- val/dataset.SequenceDataset.task_fn = @dataset.pitch_read_w_downsample
54
 
55
- # Parameters for model.UNet:
56
  # ==============================================================================
57
- model.UNet.dropout = 0.3
58
- model.UNet.features = [512, 640, 1024]
59
- model.UNet.inp_dim = 1
60
- model.UNet.kernel_size = 5
61
- model.UNet.nonlinearity = 'mish'
62
- model.UNet.norm = True
63
- model.UNet.num_attns = 4
64
- model.UNet.num_convs = 4
65
- model.UNet.num_heads = 8
66
- model.UNet.project_dim = 256
67
- model.UNet.seq_len = %SEQ_LEN
68
- model.UNet.strides = [4, 2, 2]
69
- model.UNet.time_dim = 128
 
1
  from __gin__ import dynamic_registration
2
+ from gamadhani import src
3
+ from gamadhani.src import dataset
4
+ from gamadhani.src import model_diffusion
5
+ from gamadhani.src import task_functions
6
+ from gamadhani.utils import utils
7
  import torch
8
 
9
  # Macros:
 
25
  utils.build_warmed_exponential_lr_scheduler.peak_iteration = 10000
26
  utils.build_warmed_exponential_lr_scheduler.start_factor = 0.01
27
 
28
+ # Parameters for model_diffusion.UNetBase.configure_optimizers:
29
  # ==============================================================================
30
+ model_diffusion.UNetBase.configure_optimizers.optimizer_cls = @torch.optim.AdamW
31
+ model_diffusion.UNetBase.configure_optimizers.scheduler_cls = \
32
  @utils.build_warmed_exponential_lr_scheduler
33
 
34
+ # Parameters for dataset.Task:
35
  # ==============================================================================
36
+ src.dataset.Task.kwargs = {
37
+ "decoder_key" : 'pitch',
38
+ "max_clip" : 600,
39
+ "min_clip" : 200,
40
+ "min_norm_pitch" : -4915,
41
+ "pitch_downsample" : 10,
42
+ "seq_len" : %SEQ_LEN,
43
+ "time_downsample" : 2}
44
+
45
 
46
  # Parameters for train/dataset.pitch_read_w_downsample:
47
  # ==============================================================================
48
+ # train/dataset.Task.kwargs = {"transpose_pitch": %TRANSPOSE_VALUE}
49
 
50
+ # Parameters for train/dataset.Task:
51
  # ==============================================================================
52
+ src.dataset.Task.read_fn = @src.task_functions.pitch_read_downsample_diff
53
+ src.dataset.Task.invert_fn = @src.task_functions.invert_pitch_read_downsample_diff
54
 
 
 
 
55
 
56
+ # Parameters for model_diffusion.UNet:
57
  # ==============================================================================
58
+ model_diffusion.UNet.dropout = 0.3
59
+ model_diffusion.UNet.features = [512, 640, 1024]
60
+ model_diffusion.UNet.inp_dim = 1
61
+ model_diffusion.UNet.kernel_size = 5
62
+ model_diffusion.UNet.nonlinearity = 'mish'
63
+ model_diffusion.UNet.norm = True
64
+ model_diffusion.UNet.num_attns = 4
65
+ model_diffusion.UNet.num_convs = 4
66
+ model_diffusion.UNet.num_heads = 8
67
+ model_diffusion.UNet.project_dim = 256
68
+ model_diffusion.UNet.seq_len = %SEQ_LEN
69
+ model_diffusion.UNet.strides = [4, 2, 2]
70
+ model_diffusion.UNet.time_dim = 128
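
The updated pitch config replaces the per-function bindings of the old src.dataset.pitch_read_w_downsample with a single src.dataset.Task whose read_fn/invert_fn point into gamadhani.src.task_functions, and renames every model reference to model_diffusion. A hedged sketch of reading such bindings back at run time follows; it assumes the GaMaDHaNi package is importable so the configurables named in the file are registered, and the exact binding keys may need adjusting to match how gin registers those modules:

import os
import gin
# Importing the modules listed at the top of the config registers their gin configurables.
from gamadhani.src import dataset, model_diffusion, task_functions  # noqa: F401
from gamadhani.utils import utils  # noqa: F401

gin.parse_config_file(os.path.join("models/diffusion_pitch", "config.gin"))

# Macros and bindings can then be queried, as the pre-refactor app.py did with %AUDIO_SEQ_LEN.
seq_len = gin.query_parameter("%SEQ_LEN")
features = gin.query_parameter("model_diffusion.UNet.features")
print(seq_len, features)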
models/pitch_to_audio/config.gin CHANGED
@@ -1,8 +1,9 @@
1
  from __gin__ import dynamic_registration
2
- from src import dataset
3
- from src import model
4
- from src import pitch_to_audio_utils
5
- from src import utils
 
6
  import torch
7
 
8
  # Macros:
@@ -27,10 +28,10 @@ utils.build_warmed_exponential_lr_scheduler.eta_min = 0.1
27
  utils.build_warmed_exponential_lr_scheduler.peak_iteration = 10000
28
  utils.build_warmed_exponential_lr_scheduler.start_factor = 0.01
29
 
30
- # Parameters for model.UNetBase.configure_optimizers:
31
  # ==============================================================================
32
- model.UNetBase.configure_optimizers.optimizer_cls = @torch.optim.AdamW
33
- model.UNetBase.configure_optimizers.scheduler_cls = \
34
  @utils.build_warmed_exponential_lr_scheduler
35
 
36
  # Parameters for pitch_to_audio_utils.from_mels:
@@ -39,11 +40,6 @@ pitch_to_audio_utils.from_mels.nfft = %NFFT
39
  pitch_to_audio_utils.from_mels.num_mels = %NUM_MELS
40
  pitch_to_audio_utils.from_mels.sr = %SR
41
 
42
- # Parameters for dataset.load_cached_dataset:
43
- # ==============================================================================
44
- dataset.load_cached_dataset.audio_len = %AUDIO_SEQ_LEN
45
- dataset.load_cached_dataset.return_singer = %SINGER_CONDITIONING
46
-
47
  # Parameters for pitch_to_audio_utils.normalized_mels_to_audio:
48
  # ==============================================================================
49
  pitch_to_audio_utils.normalized_mels_to_audio.n_iter = 100
@@ -53,7 +49,13 @@ pitch_to_audio_utils.normalized_mels_to_audio.sr = %SR
53
 
54
  # Parameters for dataset.SequenceDataset:
55
  # ==============================================================================
56
- dataset.SequenceDataset.task_fn = @dataset.load_cached_dataset
57
 
58
  # Parameters for pitch_to_audio_utils.torch_gl:
59
  # ==============================================================================
@@ -65,27 +67,28 @@ pitch_to_audio_utils.torch_gl.sr = %SR
65
  # ==============================================================================
66
  pitch_to_audio_utils.torch_istft.nfft = %NFFT
67
 
68
- # Parameters for model.UNetPitchConditioned:
69
  # ==============================================================================
70
- model.UNetPitchConditioned.audio_seq_len = %AUDIO_SEQ_LEN
71
- model.UNetPitchConditioned.cfg = True
72
- model.UNetPitchConditioned.cond_drop_prob = 0.2
73
- model.UNetPitchConditioned.dropout = 0.3
74
- model.UNetPitchConditioned.f0_dim = 128
75
- model.UNetPitchConditioned.features = [512, 640, 1024]
76
- model.UNetPitchConditioned.inp_dim = %NUM_MELS
77
- model.UNetPitchConditioned.kernel_size = 5
78
- model.UNetPitchConditioned.log_samples_every = 10
79
- model.UNetPitchConditioned.log_wandb_samples_every = 50
80
- model.UNetPitchConditioned.nonlinearity = 'mish'
81
- model.UNetPitchConditioned.norm = False
82
- model.UNetPitchConditioned.num_attns = 4
83
- model.UNetPitchConditioned.num_convs = 4
84
- model.UNetPitchConditioned.num_heads = 8
85
- model.UNetPitchConditioned.project_dim = 256
86
- model.UNetPitchConditioned.singer_conditioning = %SINGER_CONDITIONING
87
- model.UNetPitchConditioned.singer_dim = 128
88
- model.UNetPitchConditioned.singer_vocab = 55
89
- model.UNetPitchConditioned.sr = %SR
90
- model.UNetPitchConditioned.strides = [4, 2, 2]
91
- model.UNetPitchConditioned.time_dim = 128
 
 
1
  from __gin__ import dynamic_registration
2
+ from gamadhani import src
3
+ from gamadhani.src import dataset
4
+ from gamadhani.src import model_diffusion
5
+ from gamadhani.utils import pitch_to_audio_utils
6
+ from gamadhani.utils import utils
7
  import torch
8
 
9
  # Macros:
 
28
  utils.build_warmed_exponential_lr_scheduler.peak_iteration = 10000
29
  utils.build_warmed_exponential_lr_scheduler.start_factor = 0.01
30
 
31
+ # Parameters for model_diffusion.UNetPitchConditioned.configure_optimizers:
32
  # ==============================================================================
33
+ model_diffusion.UNetPitchConditioned.configure_optimizers.optimizer_cls = @torch.optim.AdamW
34
+ model_diffusion.UNetPitchConditioned.configure_optimizers.scheduler_cls = \
35
  @utils.build_warmed_exponential_lr_scheduler
36
 
37
  # Parameters for pitch_to_audio_utils.from_mels:
 
40
  pitch_to_audio_utils.from_mels.num_mels = %NUM_MELS
41
  pitch_to_audio_utils.from_mels.sr = %SR
42
 
43
  # Parameters for pitch_to_audio_utils.normalized_mels_to_audio:
44
  # ==============================================================================
45
  pitch_to_audio_utils.normalized_mels_to_audio.n_iter = 100
 
49
 
50
  # Parameters for dataset.SequenceDataset:
51
  # ==============================================================================
52
+ dataset.SequenceDataset.task = @dataset.Task()
53
+
54
+ # Parameters for dataset.Task:
55
+ # ==============================================================================
56
+ dataset.Task.read_fn = @dataset.load_cached_dataset
57
+ dataset.Task.kwargs = {"audio_len": %AUDIO_SEQ_LEN,
58
+ "return_singer": %SINGER_CONDITIONING}
59
 
60
  # Parameters for pitch_to_audio_utils.torch_gl:
61
  # ==============================================================================
 
67
  # ==============================================================================
68
  pitch_to_audio_utils.torch_istft.nfft = %NFFT
69
 
70
+ # Parameters for model_diffusion.UNetPitchConditioned:
71
  # ==============================================================================
72
+ model_diffusion.UNetPitchConditioned.audio_seq_len = %AUDIO_SEQ_LEN
73
+ model_diffusion.UNetPitchConditioned.cfg = True
74
+ model_diffusion.UNetPitchConditioned.cond_drop_prob = 0.2
75
+ model_diffusion.UNetPitchConditioned.dropout = 0.3
76
+ model_diffusion.UNetPitchConditioned.f0_dim = 128
77
+ model_diffusion.UNetPitchConditioned.features = [512, 640, 1024]
78
+ model_diffusion.UNetPitchConditioned.inp_dim = %NUM_MELS
79
+ model_diffusion.UNetPitchConditioned.kernel_size = 5
80
+ model_diffusion.UNetPitchConditioned.log_samples_every = 10
81
+ model_diffusion.UNetPitchConditioned.log_wandb_samples_every = 50
82
+ model_diffusion.UNetPitchConditioned.loss_w_padding = True
83
+ model_diffusion.UNetPitchConditioned.nonlinearity = 'mish'
84
+ model_diffusion.UNetPitchConditioned.norm = False
85
+ model_diffusion.UNetPitchConditioned.num_attns = 4
86
+ model_diffusion.UNetPitchConditioned.num_convs = 4
87
+ model_diffusion.UNetPitchConditioned.num_heads = 8
88
+ model_diffusion.UNetPitchConditioned.project_dim = 256
89
+ model_diffusion.UNetPitchConditioned.singer_conditioning = %SINGER_CONDITIONING
90
+ model_diffusion.UNetPitchConditioned.singer_dim = 128
91
+ model_diffusion.UNetPitchConditioned.singer_vocab = 55
92
+ model_diffusion.UNetPitchConditioned.sr = %SR
93
+ model_diffusion.UNetPitchConditioned.strides = [4, 2, 2]
94
+ model_diffusion.UNetPitchConditioned.time_dim = 128
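
The pitch-to-audio config gains loss_w_padding and moves the optimizer bindings onto model_diffusion.UNetPitchConditioned, but the data plumbing is unchanged: the model consumes %NUM_MELS-band mel frames over %AUDIO_SEQ_LEN samples, with one interpolated f0 value per frame. The removed dataset code used hop = nfft // 4, so the frame count works out as in the sketch below; the numbers are placeholders for illustration (the real values live in the elided Macros block):

# Placeholder values; the real ones are set in the Macros section of this config.gin.
nfft = 2048              # default nfft in the old mel_pitch task function
sr = 16000               # app.py resamples input audio to 16 kHz
audio_seq_len = 4 * sr   # pretend the model sees 4 s of audio

hop = nfft // 4                       # hop size implied by the removed dataset code
n_mel_frames = audio_seq_len // hop   # mel frames (and interpolated f0 values) per example
print(hop, n_mel_frames)              # 512, 125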
requirements.txt CHANGED
@@ -1,22 +1,4 @@
1
- absl_py==1.4.0
2
- einops==0.8.0
3
- gin_config==0.5.0
4
- joblib==1.2.0
5
- librosa==0.10.0
6
- lmdb==1.4.1
7
- matplotlib==3.9.2
8
- numpy==1.24.4
9
- pandas==2.0.3
10
- protobuf==3.20.3
11
- pytorch_lightning==1.9.0
12
- scikit_learn==1.2.0
13
- setuptools==67.8.0
14
- torch==2.4.0
15
- torchaudio==2.4.0
16
- tqdm==4.65.0
17
- wandb==0.15.4
18
- x_transformers==1.30.2
19
  crepe==0.0.15
20
  hmmlearn==0.3.2
21
  tensorflow==2.17.0
22
-
1
  crepe==0.0.15
2
  hmmlearn==0.3.2
3
  tensorflow==2.17.0
4
+ GaMaDHaNi @ git+https://github.com/snnithya/GaMaDHaNi.git@782dde8f48ff15a50394bcc7506df1ece0e0310e
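
Most of the old pinned dependencies are gone because they now come in transitively through the GaMaDHaNi package, which is itself pinned to a specific commit. Assuming pip install -r requirements.txt succeeds, a quick sanity check is to import the entry points the new app.py uses:

# Confirms the pinned GaMaDHaNi commit exposes the modules imported by app.py above.
from gamadhani.utils.generate_utils import load_pitch_fns, load_audio_fns
import gamadhani.utils.pitch_to_audio_utils as p2a
from gamadhani.utils.utils import get_device

print(get_device(), hasattr(p2a, "interpolate_pitch"),
      load_pitch_fns.__module__, load_audio_fns.__module__)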
src/dataset.py DELETED
@@ -1,312 +0,0 @@
1
- from typing import Callable, Dict, Optional, Tuple
2
- import lmdb
3
- import torch
4
- import pdb
5
- import numpy as np
6
- from torch.utils.data import Dataset
7
- from random import randint
8
- from sklearn.preprocessing import QuantileTransformer
9
- # from protobuf.data_example import AudioExample
10
- import gin
11
- import sys
12
- import src.pitch_to_audio_utils as p2a
13
-
14
- TensorDict = Dict[str, torch.Tensor]
15
-
16
- @gin.configurable
17
- class SequenceDataset(Dataset):
18
-
19
- def __init__(
20
- self,
21
- db_path: str,
22
- task_fn: Optional[Callable[[TensorDict], TensorDict]] = None,
23
- device: Optional[torch.device] = None
24
- ) -> None:
25
- super().__init__()
26
- self._env = None
27
- self._keys = None
28
- self._db_path = db_path
29
- self.task_fn = task_fn
30
- self.device = device
31
-
32
- def __len__(self):
33
- return len(self.keys)
34
-
35
- def __getitem__(self, index):
36
- # pdb.set_trace()
37
- with self.env.begin() as txn:
38
- ae = AudioExample(txn.get(self.keys[index]))
39
- ae = ae.as_dict()
40
- if self.task_fn is not None:
41
- ae = self.task_fn(ae)
42
- if self.device is not None:
43
- ae = {k: torch.tensor(v, device=self.device) for k, v in ae.items()}
44
- return ae
45
-
46
- @property
47
- def env(self):
48
- if self._env is None:
49
- self._env = lmdb.open(
50
- self._db_path,
51
- lock=False,
52
- readahead=False,
53
- )
54
- return self._env
55
-
56
- @property
57
- def keys(self):
58
- if self._keys is None:
59
- with self.env.begin(write=False) as txn:
60
- self._keys = list(txn.cursor().iternext(values=False))
61
- self._keys = self._keys
62
- return self._keys
63
-
64
- class MelPitchDataLoader(torch.utils.data.DataLoader):
65
- def __init__(self, *args, **kwargs):
66
- super().__init__(*args, **kwargs)
67
-
68
- def __iter__(self):
69
- for batch in super().__iter__():
70
- # Apply online transform to each sample in the batch
71
- audio, f0 = batch
72
-
73
- # generate mel spectrogram
74
- mel = p2a.audio_to_normalized_mels(audio) # doing mel conversion here since it is done in a batch and thus presumably faster
75
-
76
- yield zip(mel, f0)
77
-
78
- @gin.configurable
79
- def pitch_read_w_downsample(
80
- inputs: TensorDict,
81
- seq_len: int,
82
- decoder_key: str,
83
- min_norm_pitch: int,
84
- transpose_pitch: Optional[int] = None,
85
- time_downsample: int = 1,
86
- pitch_downsample: int = 1,
87
- qt_transform: Optional[QuantileTransformer] = None,
88
- min_clip: int = 200,
89
- max_clip: int = 600,
90
- add_noise_to_silence: bool = False
91
- ):
92
- # pdb.set_trace()
93
- # print(min_norm_pitch, seq_len, transpose_pitch, qt_transform)
94
- data = inputs[decoder_key]["data"]
95
- if seq_len is not None:
96
- start = randint(0, data.shape[0] - seq_len*time_downsample - 1)
97
- end = start + seq_len*time_downsample
98
- f0 = inputs[decoder_key]['data'][start:end:time_downsample].copy()
99
- else:
100
- f0 = data.copy()
101
-
102
- # normalize pitch
103
- f0[f0 == 0] = np.nan
104
- norm_f0 = f0.copy()
105
- norm_f0[~np.isnan(norm_f0)] = (1200) * np.log2(norm_f0[~np.isnan(norm_f0)] / 440)
106
- del f0
107
-
108
- # descretize pitch
109
- norm_f0[~np.isnan(norm_f0)] = np.around(norm_f0[~np.isnan(norm_f0)])
110
- norm_f0[~np.isnan(norm_f0)] = norm_f0[~np.isnan(norm_f0)] - (min_norm_pitch)
111
-
112
- norm_f0[~np.isnan(norm_f0)] = norm_f0[~np.isnan(norm_f0)] // pitch_downsample + 1 # reserve 0 for silence
113
- # data augmentation
114
- if transpose_pitch:
115
- transpose_amt = randint(-transpose_pitch, transpose_pitch) # in cents
116
- transposed_values = norm_f0[~np.isnan(norm_f0)] + (transpose_amt//pitch_downsample)
117
- norm_f0[~np.isnan(norm_f0)] = transposed_values
118
-
119
- # clip values HACK to change
120
- norm_f0[~np.isnan(norm_f0)] = np.clip(norm_f0[~np.isnan(norm_f0)], min_clip, max_clip)
121
-
122
- # add silence token of min_clip - 4
123
- if add_noise_to_silence:
124
- norm_f0[np.isnan(norm_f0)] = min_clip - 4 + np.clip(np.random.normal(size=norm_f0[np.isnan(norm_f0)].shape), -3, 3) # making sure noise is between -3 and 3 and thus won't spill into pitched values
125
- else:
126
- norm_f0[np.isnan(norm_f0)] = min_clip - 4
127
-
128
- if qt_transform:
129
- qt_inp = norm_f0.reshape(-1, 1)
130
- norm_f0 = qt_transform.transform(qt_inp).reshape(-1)
131
-
132
- return norm_f0.reshape(1, -1)
133
-
134
- def hz_to_cents(f0, ref=440, min_norm_pitch=0, pitch_downsample=1, min_clip=200, max_clip=600, silence_token=None):
135
- # pdb.set_trace()
136
- f0[f0 == 0] = np.nan
137
- norm_f0 = f0.copy()
138
- norm_f0[~np.isnan(norm_f0)] = (1200) * np.log2(norm_f0[~np.isnan(norm_f0)] / ref)
139
- # descretize pitch
140
- norm_f0[~np.isnan(norm_f0)] = np.around(norm_f0[~np.isnan(norm_f0)])
141
- norm_f0[~np.isnan(norm_f0)] = norm_f0[~np.isnan(norm_f0)] - (min_norm_pitch)
142
- norm_f0[~np.isnan(norm_f0)] = norm_f0[~np.isnan(norm_f0)] // pitch_downsample + 1 # reserve 0 for silence
143
- norm_f0[~np.isnan(norm_f0)] = np.clip(norm_f0[~np.isnan(norm_f0)], min_clip, max_clip) #HACK
144
- if silence_token is not None:
145
- norm_f0[np.isnan(norm_f0)] = silence_token
146
-
147
-
148
-
149
- return norm_f0
150
-
151
- @gin.configurable
152
- def mel_pitch(
153
- inputs: TensorDict,
154
- min_norm_pitch: int,
155
- audio_seq_len: int=None,
156
- pitch_downsample: int = 1,
157
- qt_transform: Optional[QuantileTransformer] = None,
158
- min_clip: int = 200,
159
- max_clip: int = 600,
160
- nfft: int = 2048,
161
- convert_audio_to_mel: bool = False
162
- ):
163
- hop_size = nfft // 4
164
- audio_data = inputs['audio']['data']
165
- audio_sr = inputs['audio']['sampling_rate']
166
- pitch_data = inputs['pitch']['data']
167
- pitch_sr = inputs['pitch']['sampling_rate']
168
- # pdb.set_trace()
169
- if audio_seq_len is not None:
170
- # if audio_seq_len is given, cuts audio/pitch else returns the entire chunk
171
- pitch_seq_len = np.around((audio_seq_len/audio_sr) * pitch_sr ).astype(int)
172
- pitch_start = randint(0, pitch_data.shape[0] - pitch_seq_len - 1)
173
- pitch_end = pitch_start + pitch_seq_len
174
- pitch_data = pitch_data[pitch_start:pitch_end]
175
- audio_start = np.around(pitch_start * audio_sr // pitch_sr).astype(int)
176
- audio_end = np.around(audio_start + audio_seq_len).astype(int)
177
- # pdb.set_trace()
178
- audio_data = audio_data[audio_start:audio_end]
179
- else:
180
- pitch_seq_len = np.around((audio_data.shape[0]/audio_sr) * pitch_sr ).astype(int)
181
- audio_data = p2a.audio_to_normalized_mels(torch.Tensor(audio_data).unsqueeze(0), qt=qt_transform).numpy()[0]
182
-
183
- pitch_data = hz_to_cents(pitch_data, min_norm_pitch=min_norm_pitch, pitch_downsample=pitch_downsample, min_clip=min_clip, max_clip=max_clip)
184
-
185
- if audio_seq_len is not None:
186
- # linearly interpolate pitch data to match audio sequence length, if audio_seq_len is given
187
- pitch_inds = np.linspace(0, pitch_data.shape[0], num=audio_seq_len//hop_size, endpoint=False) #check here
188
- pitch_data = np.interp(pitch_inds, np.arange(0, pitch_data.shape[0]), pitch_data)
189
-
190
- # replace nan (aka silences) with min_clip - 4
191
- pitch_data[np.isnan(pitch_data)] = min_clip - 4
192
-
193
- return audio_data, pitch_data
194
- def running_average(signal, window_size):
195
-
196
- weights = np.ones(int(window_size)) / window_size
197
- pad_width = len(weights) // 2
198
- padded_signal = np.pad(signal, pad_width, mode='symmetric')
199
- # Perform the convolution
200
- smoothed_signal = np.convolve(padded_signal, weights, mode='valid')
201
- if window_size % 2 == 0:
202
- smoothed_signal = smoothed_signal[:-1]
203
- return smoothed_signal
204
-
205
- @gin.configurable
206
- def pitch_coarse_condition(
207
- inputs: TensorDict,
208
- min_norm_pitch: int,
209
- pitch_seq_len: int=None,
210
- pitch_downsample: int = 1,
211
- time_downsample: int = 1,
212
- qt_transform: Optional[QuantileTransformer] = None,
213
- min_clip: int = 200,
214
- max_clip: int = 600,
215
- add_noise: bool = True,
216
- avg_window_size: float = 1 # window size in seconds
217
- ):
218
-
219
- pitch_data = inputs['pitch']['data']
220
- if pitch_seq_len is not None:
221
- pitch_start = randint(0, pitch_data.shape[0] - pitch_seq_len*time_downsample - 1)
222
- pitch_end = pitch_start + pitch_seq_len*time_downsample
223
- pitch_data = pitch_data[pitch_start:pitch_end:time_downsample]
224
- pitch_data = hz_to_cents(pitch_data, min_norm_pitch=min_norm_pitch, pitch_downsample=pitch_downsample, min_clip=min_clip, max_clip=max_clip)
225
-
226
- # extract coarse pitch condition
227
- pitch_sr = inputs['pitch']['sampling_rate'] // time_downsample
228
- avg_pitch = running_average(pitch_data, np.around(pitch_sr * avg_window_size).astype(int))
229
- # replace nan (aka silences) with min_clip - 4
230
- if add_noise:
231
- pitch_data[np.isnan(pitch_data)] = min_clip - 4 + np.clip(np.random.normal(size=pitch_data[np.isnan(pitch_data)].shape), -3, 3) # making sure noise is between -3 and 3 and thus won't spill into pitched values
232
- avg_pitch[np.isnan(avg_pitch)] = min_clip - 4 + np.clip(np.random.normal(size=avg_pitch[np.isnan(avg_pitch)].shape), -3, 3) # making sure noise is between -3 and 3 and thus won't spill into pitched values
233
- else:
234
- pitch_data[np.isnan(pitch_data)] = min_clip - 4
235
-
236
- if qt_transform:
237
- # apply qt transform
238
- qt_inp = pitch_data.reshape(-1, 1)
239
- pitch_data = qt_transform.transform(qt_inp).reshape(-1)
240
- avg_qt_inp = avg_pitch.reshape(-1, 1)
241
- avg_pitch = qt_transform.transform(avg_qt_inp).reshape(-1)
242
- # pdb.set_trace()
243
- return pitch_data, avg_pitch
244
-
245
- @gin.configurable
246
- def mel_pitch_coarse_condition(
247
- inputs: TensorDict,
248
- min_norm_pitch: int,
249
- audio_seq_len: int=None,
250
- pitch_downsample: int = 1,
251
- qt_transform: Optional[QuantileTransformer] = None,
252
- min_clip: int = 200,
253
- max_clip: int = 600,
254
- nfft: int = 2048,
255
- avg_window_size: float = 1 # duration of avg window in seconds
256
- ):
257
- mel, pitch = mel_pitch(inputs, min_norm_pitch, audio_seq_len, pitch_downsample, qt_transform, min_clip, max_clip, nfft)
258
- silence_token = min_clip - 4
259
- avg_pitch = pitch.copy()
260
- avg_pitch[pitch == silence_token] = np.nan
261
-
262
- time = mel.shape[1]/inputs['audio']['sampling_rate']
263
- pitch_sr = pitch.shape[0]/time
264
-
265
- avg_pitch = running_average(avg_pitch, np.around(pitch_sr*avg_window_size))
266
- avg_pitch[np.isnan(avg_pitch)] = silence_token
267
-
268
- return mel, pitch, avg_pitch
269
-
270
- def load_cached_audio(
271
- inputs: TensorDict,
272
- audio_len: Optional[float] = None,
273
- ) -> torch.Tensor:
274
-
275
- audio_data = inputs['audio']['data']
276
- if audio_len is not None:
277
- audio_start = randint(0, audio_data.shape[1] - audio_len - 1)
278
- audio_end = audio_start + audio_len
279
- audio_data = audio_data[:, audio_start:audio_end]
280
- return torch.Tensor(audio_data)
281
-
282
- # need to add a silence token / range, calculate pitch avg
283
- def load_cached_dataset(
284
- inputs: TensorDict,
285
- audio_len: float,
286
- return_singer: bool = False
287
- ) -> Tuple[torch.Tensor, torch.Tensor]:
288
- # pdb.set_trace()
289
- audio_sr = inputs['audio']['sampling_rate']
290
- audio_data = inputs['audio']['data']
291
- audio_start = randint(0, audio_data.shape[1] - audio_len - 1)
292
- audio_end = audio_start + audio_len
293
- audio_data = audio_data[:, audio_start:audio_end]
294
-
295
- pitch_sr = inputs['pitch']['sampling_rate']
296
- pitch_len = np.floor(audio_len / audio_sr * pitch_sr).astype(int)
297
- pitch_data = inputs['pitch']['data']
298
- pitch_start = np.floor(audio_start * pitch_sr / audio_sr).astype(int)
299
- pitch_end = pitch_start + pitch_len
300
- pitch_data = pitch_data[pitch_start:pitch_end]
301
-
302
- # interpolate data to match audio length
303
- pitch_inds = np.linspace(0, pitch_data.shape[0], num=audio_len, endpoint=False) #check here
304
- pitch_data = np.interp(pitch_inds, np.arange(0, pitch_data.shape[0]), pitch_data)
305
-
306
- if return_singer:
307
- singer = torch.Tensor([inputs['global_conditions']['singer']])
308
- else:
309
- singer = None
310
-
311
- # print(audio_data.shape, pitch_data.shape, singer.shape if singer is not None else None)
312
- return torch.Tensor(audio_data), torch.Tensor(pitch_data), singer
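
The core of the deleted pitch_read_w_downsample / hz_to_cents pair is the Hz-to-bin mapping: cents = 1200 * log2(f0 / 440), rounded, shifted by min_norm_pitch, integer-divided by pitch_downsample (+1 so that 0 stays reserved for silence), clipped to [min_clip, max_clip], with unvoiced frames sent to min_clip - 4. Below is a standalone worked example using the values bound in models/diffusion_pitch/config.gin (min_norm_pitch=-4915, pitch_downsample=10, min_clip=200, max_clip=600); it is a sketch of the same arithmetic, not a drop-in replacement for the removed functions:

import numpy as np

def hz_to_bin(f0_hz, min_norm_pitch=-4915, pitch_downsample=10, min_clip=200, max_clip=600):
    """Map f0 in Hz to the discrete pitch bins used by the removed dataset code (0 Hz = silence)."""
    f0 = np.asarray(f0_hz, dtype=float)
    bins = np.full_like(f0, min_clip - 4)                 # silence token
    voiced = f0 > 0
    cents = np.around(1200 * np.log2(f0[voiced] / 440))   # cents relative to A4, discretized
    bins[voiced] = np.clip((cents - min_norm_pitch) // pitch_downsample + 1, min_clip, max_clip)
    return bins

print(hz_to_bin([0.0, 110.0, 220.0, 440.0]))  # [196. 252. 372. 492.]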
 
src/generate_utils.py DELETED
@@ -1,88 +0,0 @@
1
- import numpy as np
2
- from typing import Optional
3
- from sklearn.preprocessing import QuantileTransformer
4
- import sys
5
- import pdb
6
- sys.path.append('../pitch-diffusion')
7
- import torch
8
- import gin
9
- from src.model import UNet, UNetPitchConditioned
10
- from functools import partial
11
- import joblib
12
- from src.dataset import hz_to_cents, pitch_read_w_downsample
13
-
14
- def invert_pitch_read(pitch,
15
- min_norm_pitch: int,
16
- time_downsample: int,
17
- pitch_downsample: int,
18
- qt_transform: Optional[QuantileTransformer],
19
- min_clip: int,
20
- max_clip: int):
21
- try:
22
- pitch = pitch.detach().cpu().numpy()
23
- except:
24
- pass
25
- if qt_transform is not None:
26
- pitch = qt_transform.inverse_transform(pitch.reshape(-1, 1))
27
- pitch.reshape(1, -1)
28
- pitch[pitch < min_clip] = np.nan
29
- pitch[~np.isnan(pitch)] = (pitch[~np.isnan(pitch)] - 1) * pitch_downsample
30
- pitch[~np.isnan(pitch)] = pitch[~np.isnan(pitch)] + min_norm_pitch
31
- pitch[~np.isnan(pitch)] = 440 * 2**(pitch[~np.isnan(pitch)] / 1200)
32
- pitch[np.isnan(pitch)] = 0
33
-
34
- return pitch, 200//time_downsample
35
-
36
- def invert_tonic(tonic: Optional[int] = None,
37
- min_norm_pitch: int = 0,
38
- min_clip: int = 200,
39
- pitch_downsample: int = 1,
40
- ):
41
- tonic += min_clip
42
- tonic = pitch_downsample * (tonic - 1)
43
- tonic += min_norm_pitch
44
- tonic = 440 * 2**(tonic / 1200)
45
-
46
- return tonic
47
-
48
- def load_processed_pitch(pitch,
49
- audio_seq_len: int,
50
- min_norm_pitch: int,
51
- pitch_downsample: int,
52
- min_clip: int,
53
- max_clip: int,
54
- ):
55
- # pdb.set_trace()
56
- pitch = hz_to_cents(pitch, min_norm_pitch=min_norm_pitch, min_clip=min_clip, max_clip=max_clip, pitch_downsample=pitch_downsample, silence_token=min_clip-4)
57
- pitch_inds = np.linspace(0, pitch.shape[0], num=audio_seq_len, endpoint=False)
58
- pitch = np.interp(pitch_inds, np.arange(0, pitch.shape[0]), pitch)
59
- return pitch
60
-
61
- def load_pitch_model(config, ckpt, qt = None, prime_file=None, device='cuda'):
62
- gin.parse_config_file(config)
63
- model = UNet()
64
- model.load_state_dict(torch.load(ckpt, map_location='cuda')['state_dict'])
65
- model.to(device)
66
- if qt is not None:
67
- qt = joblib.load(qt)
68
- if prime_file is not None:
69
- with gin.config_scope('val'): # probably have to change this
70
- with gin.unlock_config():
71
- gin.bind_parameter('dataset.pitch_read_w_downsample.qt_transform', qt)
72
- primes = np.load(prime_file, allow_pickle=True)['concatenated_array'][:, 0]
73
- else:
74
- primes = None
75
- task_fn = None
76
- task_fn = partial(pitch_read_w_downsample,
77
- seq_len=None)
78
- return model, qt, primes, task_fn
79
-
80
- def load_audio_model(config, ckpt, qt = None, device='cuda'):
81
- gin.parse_config_file(config)
82
- model = UNetPitchConditioned() # there are no gin parameters for some reason
83
- model.load_state_dict(torch.load(ckpt, map_location='cuda')['state_dict'])
84
- model.to(device)
85
- if qt is not None:
86
- qt = joblib.load(qt)
87
-
88
- return model, qt
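
invert_pitch_read above is the inverse of that mapping: optionally undo the quantile transform, treat anything below min_clip as silence, then f0 = 440 * 2**(((bin - 1) * pitch_downsample + min_norm_pitch) / 1200), returned together with a frame rate of 200 // time_downsample (100 Hz for the time_downsample=2 bound in the pitch config). A small companion sketch to the forward example after src/dataset.py, again only illustrating the removed arithmetic:

import numpy as np

def bin_to_hz(bins, min_norm_pitch=-4915, pitch_downsample=10, min_clip=200):
    """Inverse of the hz_to_bin sketch above; bins below min_clip are treated as silence (0 Hz)."""
    b = np.asarray(bins, dtype=float)
    f0 = np.zeros_like(b)
    voiced = b >= min_clip
    cents = (b[voiced] - 1) * pitch_downsample + min_norm_pitch
    f0[voiced] = 440 * 2 ** (cents / 1200)
    return f0

print(np.round(bin_to_hz([196, 252, 372, 492]), 1))  # ~[0, 110, 220, 440], up to the 10-cent bin width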
 
src/model.py DELETED
@@ -1,1130 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
- import pytorch_lightning as pl
5
- import torch.nn.functional as F
6
- import math
7
- from typing import Optional, Union
8
- import numpy as np
9
- import wandb
10
- import matplotlib.pyplot as plt
11
- import gin
12
- import os
13
- import pandas as pd
14
- import src.pitch_to_audio_utils as p2a
15
- import torchaudio
16
- from typing import Callable
17
- from pytorch_lightning.utilities import grad_norm
18
-
19
- import sys
20
- sys.path.append('..')
21
- sys.path.append('../x-transformers/')
22
- from src.utils import prob_mask_like
23
- from x_transformers.x_transformers import AttentionLayers
24
- import pdb
25
-
26
- def get_activation(act: str = 'mish'):
27
- act = act.lower()
28
- if act == 'mish':
29
- return nn.Mish()
30
- elif act == 'relu':
31
- return nn.ReLU()
32
- elif act == 'leaky_relu':
33
- return nn.LeakyReLU()
34
- elif act == 'gelu':
35
- return nn.GELU()
36
- elif act == 'swish':
37
- return nn.SiLU()
38
- else:
39
- raise ValueError(f'Activation {act} not supported')
40
-
41
- def get_weight_norm(layer):
42
- return torch.nn.utils.parametrizations.weight_norm(layer)
43
-
44
- def get_layer(layer, norm: bool):
45
- if norm:
46
- return get_weight_norm(layer)
47
- else:
48
- return layer
49
-
50
- class PositionalEncoding(nn.Module):
51
- def __init__(self, dim):
52
- super(PositionalEncoding, self).__init__()
53
- self.dim = dim
54
-
55
- def forward(self, x):
56
- shape = x.shape
57
- x = x * 100
58
- w = torch.pow(10000, (2 * torch.arange(self.dim // 2).float() / self.dim)).to(x)
59
- x = x.unsqueeze(-1) / w
60
- embed = torch.cat([torch.cos(x), torch.sin(x)], -1)
61
- embed = embed.reshape(*shape, -1)
62
- if len(shape) == 2: # f0 embedding, else time embedding
63
- embed = embed.permute(0, 2, 1)
64
- return embed
65
-
66
- class ConvBlock(nn.Module):
67
- def __init__(self,
68
- inp_dim,
69
- out_dim,
70
- kernel_size: int = 3,
71
- stride: int = 1,
72
- padding: Union[str, int] = "same",
73
- norm: bool = True,
74
- nonlinearity: Optional[str] = None,
75
- up: bool = False,
76
- dropout: float = 0.0,
77
- ):
78
- super(ConvBlock, self).__init__()
79
- self.inp_dim = inp_dim
80
- self.out_dim = out_dim
81
- # self.norm = norm
82
- # pdb.set_trace()
83
- if nonlinearity is not None:
84
- self.nonlinearity = get_activation(nonlinearity)
85
- else:
86
- self.nonlinearity = None
87
- if up:
88
- self.conv = get_layer(nn.ConvTranspose1d(inp_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding), norm)
89
- else:
90
- self.conv = get_layer(nn.Conv1d(inp_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding), norm)
91
-
92
- self.layers = nn.ModuleList()
93
- if self.nonlinearity is not None:
94
- self.layers.append(self.nonlinearity)
95
- if dropout > 0:
96
- self.layers.append(nn.Dropout(dropout))
97
- self.layers.append(self.conv)
98
-
99
- def forward(self, x):
100
- for layer in self.layers:
101
- x = layer(x)
102
- return x
103
- class UpSampleLayer(nn.Module):
104
- def __init__(self,
105
- inp_dim,
106
- out_dim,
107
- kernel_size: int = 3,
108
- stride: int = 1,
109
- padding: Union[str, int] = "same",
110
- num_convs: int = 2,
111
- norm: bool = True,
112
- nonlinearity: Optional[str] = None,
113
- dropout: float = 0.0,
114
- ):
115
- super(UpSampleLayer, self).__init__()
116
- assert num_convs > 0, "Number of convolutions must be greater than 0"
117
- self.num_convs = num_convs
118
-
119
- self.convs = nn.ModuleList([])
120
-
121
- self.convs.append(ConvBlock(inp_dim, out_dim, kernel_size=stride*2, stride=stride, padding=padding, norm=norm, nonlinearity=nonlinearity, up=True)) # first convolutional layer to upsample
122
- for ind in range(1, num_convs):
123
- self.convs.append(ConvBlock(out_dim, out_dim, kernel_size=kernel_size, stride=1, padding="same", norm=norm, nonlinearity=nonlinearity, up=False, dropout=dropout if ind == num_convs-1 else 0))
124
-
125
- def forward(self, x):
126
- for conv in self.convs:
127
- x = conv(x)
128
- return x
129
-
130
- class DownSampleLayer(nn.Module):
131
- def __init__(self,
132
- inp_dim,
133
- out_dim,
134
- kernel_size: int = 3,
135
- stride: int = 1,
136
- padding: Union[str, int] = "same",
137
- num_convs: int = 2,
138
- norm: bool = True,
139
- nonlinearity: Optional[str] = None,
140
- dropout: float = 0.0,
141
- ):
142
- super(DownSampleLayer, self).__init__()
143
- assert num_convs > 0, "Number of convolutions must be greater than 0"
144
- self.num_convs = num_convs
145
-
146
- self.convs = nn.ModuleList([])
147
-
148
- self.convs.append(ConvBlock(inp_dim, out_dim, kernel_size=stride*2, stride=stride, padding=padding, norm=norm, nonlinearity=nonlinearity, up=False)) # first convolutional layer to upsample
149
- for ind in range(1, num_convs):
150
- self.convs.append(ConvBlock(out_dim, out_dim, kernel_size=kernel_size, stride=1, padding="same", norm=norm, nonlinearity=nonlinearity, up=False, dropout=dropout if ind == num_convs-1 else 0))
151
-
152
- def forward(self, x):
153
- for conv in self.convs:
154
- x = conv(x)
155
- return x
156
-
157
- # class Attention(nn.Module):
158
- # def __init__(self,
159
- # num_heads,
160
- # num_channels,
161
- # dropout=0.0):
162
- # super(Attention, self).__init__()
163
- # self.num_heads = num_heads
164
- # self.num_channels = num_channels
165
- # self.layer_norm1 = nn.LayerNorm(self.num_channels)
166
- # self.layer_norm2 = nn.LayerNorm(self.num_channels)
167
- # self.qkv_proj = nn.Linear(self.num_channels, self.num_channels * 3, bias=False)
168
- # self.head_dim = self.num_channels // self.num_heads
169
- # self.final_proj = nn.Linear(self.num_channels, self.num_channels)
170
- # self.dropout = nn.Dropout(dropout)
171
-
172
- # def split_heads(self, x):
173
- # # input shape bs, time, channels
174
- # x = x.view(x.shape[0], x.shape[1], self.num_heads, self.head_dim)
175
- # return x.permute(0, 2, 1, 3) # bs, num_heads, time, head_dim
176
-
177
- # def forward(self, x):
178
- # # pdb.set_trace()
179
- # x = torch.permute(x, (0, 2, 1)) # bs, time, channels
180
- # residual = x
181
- # x = self.layer_norm1(x)
182
- # x = self.qkv_proj(x)
183
- # q, k, v = x.chunk(3, dim=-1)
184
-
185
- # # split heads
186
- # q = self.split_heads(q)
187
- # k = self.split_heads(k)
188
- # v = self.split_heads(v)
189
-
190
- # # calculate attention
191
- # x = torch.einsum("...td,...sd->...ts", q, k) / math.sqrt(self.head_dim)
192
- # x = self.dropout(x)
193
- # x = torch.einsum("...ts,...sd->...td", F.softmax(x, dim=-1), v) # bs, num_heads, time, head_dim
194
-
195
- # # combine heads
196
- # x = torch.permute(x, (0, 2, 1, 3)) # bs, time, num_heads, head_dim
197
- # x = x.reshape(x.shape[0], x.shape[1], self.num_heads * self.head_dim)
198
-
199
- # # final projection
200
- # x = self.final_proj(x)
201
- # x = self.layer_norm2(x + residual)
202
- # return torch.permute(x, (0, 2, 1)) # bs, channels, time
203
-
204
- class ResNetBlock(nn.Module):
205
- def __init__(self,
206
- in_channels: int,
207
- out_channels: int,
208
- dropout: float = 0.0,
209
- nonlinearity: Optional[str] = None,
210
- kernel_size: int = 3,
211
- stride: int = 1,
212
- norm: bool = True,
213
- up: bool = False,
214
- num_convs: int = 2,
215
- ):
216
- super(ResNetBlock, self).__init__()
217
-
218
- self.input_layers = nn.ModuleList([])
219
- if nonlinearity is not None:
220
- self.input_layers.append(get_activation(nonlinearity))
221
-
222
- if up:
223
- self.input_layers.append(get_layer(nn.ConvTranspose1d(in_channels, out_channels, kernel_size=stride*2, stride=stride, padding=stride//2), norm))
224
- else:
225
- if in_channels != out_channels:
226
- self.input_layers.append(get_layer(nn.Conv1d(in_channels, out_channels, kernel_size=stride*2, stride=stride, padding=stride//2), norm))
227
- elif stride > 1:
228
- self.input_layers.append(nn.AvgPool1d(stride*2, stride=stride, padding=stride//2))
229
- else:
230
- self.input_layers.append(nn.Identity())
231
-
232
- if up:
233
- self.process_layer = UpSampleLayer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=stride//2, num_convs=num_convs, norm=norm, nonlinearity=nonlinearity, dropout=dropout)
234
- else:
235
- self.process_layer = DownSampleLayer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=stride//2, num_convs=num_convs, norm=norm, nonlinearity=nonlinearity, dropout=dropout)
236
-
237
- def forward(self, x):
238
- # pdb.set_trace()
239
- inputs = x.clone()
240
- for layer in self.input_layers:
241
- inputs = layer(inputs)
242
- x = self.process_layer(x)
243
- return x + inputs
244
-
245
- @gin.configurable
246
- class UNetBase(pl.LightningModule):
247
- def __init__(self, log_grad_norms_every=10):
248
- super(UNetBase, self).__init__()
249
- self.log_grad_norms_every = log_grad_norms_every
250
-
251
- @gin.configurable
252
- def configure_optimizers(self, optimizer_cls: Callable[[], torch.optim.Optimizer],
253
- scheduler_cls: Callable[[],
254
- torch.optim.lr_scheduler._LRScheduler]):
255
- # pdb.set_trace()
256
- optimizer = optimizer_cls(self.parameters())
257
- scheduler = scheduler_cls(optimizer)
258
-
259
- return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]
260
-
261
- @gin.configurable
262
- class UNet(UNetBase):
263
- def __init__(self,
264
- inp_dim,
265
- time_dim,
266
- features,
267
- strides,
268
- kernel_size,
269
- seq_len,
270
- project_dim=None,
271
- dropout=0.0,
272
- nonlinearity=None,
273
- norm=True,
274
- num_convs=2,
275
- num_attns=2,
276
- num_heads=8,
277
- log_samples_every=10,
278
- ckpt=None,
279
- loss_w_padding=False,
280
- groups=None,
281
- nfft=None,
282
- log_grad_norms_every=10
283
- ):
284
- super(UNet, self).__init__()
285
- self.time_dim = time_dim
286
- self.features = features
287
- self.strides = strides
288
- self.kernel_size = kernel_size
289
- self.seq_len = seq_len
290
- self.log_samples_every = log_samples_every
291
- self.ckpt = ckpt
292
- self.strides_prod = np.prod(strides)
293
- self.loss_w_padding = loss_w_padding
294
-
295
- if log_grad_norms_every is not None:
296
- assert log_grad_norms_every > 0, "log_grad_norms_every must be greater than 0"
297
- self.log_grad_norms_every = log_grad_norms_every
298
-
299
- if project_dim is None:
300
- project_dim = features[0]
301
- self.initial_projection = nn.Conv1d(inp_dim, project_dim, kernel_size=1)
302
- self.positional_encoding = PositionalEncoding(time_dim)
303
-
304
- features = [project_dim] + features
305
- strides = [None] + strides
306
-
307
- self.downsample_layers = nn.ModuleList([
308
- ResNetBlock(features[ind-1] + time_dim,
309
- features[ind],
310
- kernel_size=kernel_size,
311
- stride=strides[ind],
312
- dropout=dropout,
313
- nonlinearity=nonlinearity,
314
- norm=norm,
315
- num_convs=num_convs,
316
- ) for ind in range(1, len(features))
317
- ])
318
-
319
- # self.attention_layers = nn.ModuleList(
320
- # [Attention(num_heads=num_heads, num_channels=features[-1], dropout=dropout) for _ in range(num_attns)]
321
- # )
322
-
323
- self.attention_layers = AttentionLayers(
324
- dim = features[-1],
325
- heads = num_heads,
326
- depth = num_attns,
327
- )
328
-
329
- self.upsample_layers = nn.ModuleList([
330
- ResNetBlock(features[ind] * 2 + time_dim, # input size defined by features + skip dimension + time dimension
331
- features[ind-1],
332
- kernel_size=kernel_size,
333
- stride=strides[ind],
334
- dropout=dropout,
335
- nonlinearity=nonlinearity,
336
- norm=norm,
337
- num_convs=num_convs,
338
- up=True
339
- ) for ind in range(len(features) - 1, 0, -1)
340
- ])
341
- self.final_projection = nn.Conv1d(2*project_dim, inp_dim, kernel_size=1)
342
-
343
- def pad_to(self, x, strides):
344
- # modified from: https://stackoverflow.com/questions/66028743/how-to-handle-odd-resolutions-in-unet-architecture-pytorch
345
- l = x.shape[-1]
346
-
347
- if l % strides > 0:
348
- new_l = l + strides - l % strides
349
- else:
350
- new_l = l
351
-
352
- ll, ul = int((new_l-l) / 2), int(new_l-l) - int((new_l-l) / 2)
353
- pads = (ll, ul)
354
-
355
- out = F.pad(x, pads, "reflect").to(x)
356
-
357
- return out, pads
358
-
359
- def unpad(self, x, pad):
360
- # modified from: https://stackoverflow.com/questions/66028743/how-to-handle-odd-resolutions-in-unet-architecture-pytorch
361
- if pad[0]+pad[1] > 0:
362
- x = x[:,:,pad[0]:-pad[1]]
363
- return x
364
-
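pad_to reflect-pads the time axis up to the next multiple of the product of the strides, splitting the padding across both ends, and unpad strips it again. A small numeric check of the same arithmetic (standalone sketch; the lengths 750 and 64 are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 750)       # length not divisible by the stride product
strides_prod = 64

l = x.shape[-1]
new_l = l + strides_prod - l % strides_prod if l % strides_prod else l   # 768
ll, ul = (new_l - l) // 2, (new_l - l) - (new_l - l) // 2                # 9, 9
padded = F.pad(x, (ll, ul), "reflect")
print(padded.shape)                 # torch.Size([1, 1, 768])
print(padded[:, :, ll:-ul].shape)   # torch.Size([1, 1, 750]), i.e. unpad recovers the input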
365
- def forward(self, x, time):
366
-
367
- # INITIAL PROJECTION
368
- x = self.initial_projection(x)
369
-
370
- # TIME CONDITIONING
371
- time = self.positional_encoding(time)
372
-
373
- def _concat_time(x, time):
374
- time = time.unsqueeze(2).expand(-1, -1, x.shape[-1])
375
- x = torch.cat([x, time], -2)
376
- return x
377
-
378
- skips = []
379
-
380
- # DOWNSAMPLING
381
- for ind, downsample_layer in enumerate(self.downsample_layers):
382
- # print(f'Down sample layer {ind}')
383
- skips.append(x)
384
- x = _concat_time(x, time)
385
- x = downsample_layer(x)
386
- skips.append(x)
387
-
388
- # BOTTLENECK ATTENTION
389
- x = torch.permute(x, (0, 2, 1))
390
- x = self.attention_layers(x)
391
- x = torch.permute(x, (0, 2, 1))
392
- # pdb.set_trace()
393
- # UPSAMPLING
394
- for ind, upsample_layer in enumerate(self.upsample_layers):
395
- # print(f'Up sample layer {ind}')
396
- x = _concat_time(x, time)
397
- x = torch.cat([x, skips.pop(-1)], 1)
398
- x = upsample_layer(x)
399
- x = torch.cat([x, skips.pop(-1)], 1)
400
-
401
- # FINAL PROJECTION
402
- x = self.final_projection(x)
403
- return x
404
-
405
- def loss(self, x):
406
- # pdb.set_trace()
407
- padded_x, padding = self.pad_to(x, self.strides_prod)
408
- t = torch.rand((padded_x.shape[0],)).to(padded_x)
409
- noise = torch.normal(0, 1, padded_x.shape).to(padded_x)
410
- # print(t.device, noise.device, x.device)
411
- x_t = t[:, None, None] * padded_x + (1 - t[:, None, None]) * noise
412
- # print(t.device, noise.device, x_t.device, x.device)
413
- padded_y = self.forward(x_t, t)
414
- unpadded_y = self.unpad(padded_y, padding)
415
-
416
- if self.loss_w_padding:
417
- target = padded_x - noise
418
- return torch.mean((padded_y - target) ** 2)
419
- else:
420
- target = x - self.unpad(noise, padding) # x1 - x0
421
- return torch.mean((unpadded_y - target) ** 2)
422
-
423
-
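The loss above is essentially the rectified-flow / flow-matching objective: the network is shown the linear interpolation x_t = t*x1 + (1-t)*x0 between Gaussian noise x0 and data x1 and regresses the constant velocity x1 - x0. A standalone restatement without the padding bookkeeping (a sketch; model(x_t, t) stands in for the forward pass above):

import torch

def flow_matching_loss(model, x1):
    # x1: clean data batch, shape (bs, channels, time)
    t = torch.rand(x1.shape[0], device=x1.device)           # one time per example
    x0 = torch.randn_like(x1)                               # noise endpoint
    x_t = t[:, None, None] * x1 + (1 - t[:, None, None]) * x0
    target = x1 - x0                                        # straight-line velocity
    return torch.mean((model(x_t, t) - target) ** 2)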
424
- def on_before_optimizer_step(self, optimizer, *_):
425
- def calculate_grad_norm(module_list, norm_type=2):
426
- total_norm = 0
427
- if isinstance(module_list, nn.Module):
428
- module_list = [module_list]
429
- for module in module_list:
430
- for name, param in module.named_parameters():
431
- if param.requires_grad:
432
- param_norm = torch.norm(param.grad.detach(), p=norm_type)
433
- total_norm += param_norm**2
434
- # pdb.set_trace()
435
- total_norm = torch.sqrt(total_norm)
436
- return total_norm
437
-
438
- if self.log_grad_norms_every is not None and self.global_step % self.log_grad_norms_every == 0:
439
- self.log('Grad Norm/Downsample Layers', calculate_grad_norm(self.downsample_layers))
440
- self.log('Grad Norm/Attention Layers', calculate_grad_norm(self.attention_layers))
441
- self.log('Grad Norm/Upsample Layers', calculate_grad_norm(self.upsample_layers))
442
-
443
- def training_step(self, batch, batch_idx):
444
- # print('\n', batch_idx, batch.shape)
445
- x = batch
446
- loss = self.loss(x)
447
-
448
- # log grad_norms
449
- # if self.log_grad_norms_every > 0 and self.current_epoch % self.log_grad_norms_every == 0:
450
-
451
- # for ind, attention_layer in enumerate(self.attention_layers):
452
- # self.log(f'Grad Norm/Attention Layer {ind}', grad_norm(attention_layer, norm_type=2))
453
- # for ind, downsample_layer in enumerate(self.downsample_layers):
454
- # self.log(f'Grad Norm/Downsample Layer {ind}', grad_norm(downsample_layer, norm_type=2))
455
-
456
- self.log('train_loss', loss)
457
-
458
- return loss
459
-
460
- def validation_step(self, batch, batch_idx):
461
- x = batch
462
- loss = self.loss(x)
463
- self.log('val_loss', loss)
464
- return loss
465
-
466
- def sample_fn(self, batch_size: int, num_steps: int, prime: Optional[torch.Tensor] = None):
467
- # CREATE INITIAL NOISE
468
- if prime is not None:
469
- prime = prime.to(self.device)
470
- noise = torch.normal(mean=0, std=1, size=(batch_size, 1, self.seq_len)).to(self.device)
471
- x_alpha_t = noise.clone()
472
- t_array = torch.ones((batch_size,)).to(self.device)
473
- # x_alpha_ts = {}
474
- with torch.no_grad():
475
- # SAMPLE FROM MODEL
476
- for t in np.linspace(0, 1, num_steps + 1)[:-1]:
477
- t_tensor = torch.tensor(t)
478
- alpha_t = t_tensor * t_array
479
- alpha_t = alpha_t.unsqueeze(1).unsqueeze(2).to(self.device)
480
- if prime is not None:
481
- x_alpha_t[:, :, :prime.shape[-1]] = ((1 - alpha_t) * noise[:, :, :prime.shape[-1]]) + (alpha_t * prime) # fill in the prime in the beginning of each x_t
482
- diff = self.forward(x_alpha_t, t_tensor * t_array)
483
- x_alpha_t = x_alpha_t + 1 / num_steps * diff
484
- # x_alpha_ts[t] = x_alpha_t
485
- # if prime is not None:
486
- # x_alpha_t[:, :, :prime.shape[-1]] = prime
487
- return x_alpha_t
488
-
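sample_fn integrates the learned velocity with fixed-step Euler updates from t=0 (noise) to t=1 (data); when a prime is provided, its span is re-blended with noise at every step so the rest of the sequence is generated as a continuation. The core loop with the priming stripped out (a sketch, not a drop-in replacement):

import numpy as np
import torch

@torch.no_grad()
def euler_sample(model, shape, num_steps, device="cuda"):
    x = torch.randn(shape, device=device)          # state at t = 0
    ones = torch.ones(shape[0], device=device)
    for t in np.linspace(0, 1, num_steps + 1)[:-1]:
        x = x + (1.0 / num_steps) * model(x, float(t) * ones)
    return x                                       # approximate sample at t = 1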
489
- def sample_sdedit(self, cond, batch_size, num_steps, t0=0.5):
490
- # pdb.set_trace()
491
- t0_steps = int(t0*num_steps)
492
- # iterate to get x0
493
- t_array = torch.ones((batch_size,)).to(self.device)
494
- x_alpha_t = cond.clone()
495
- with torch.no_grad():
496
- for t in np.linspace(t0, 0, t0_steps + 1)[:-1]:
497
- t_tensor = torch.tensor(t)
498
- x_alpha_t = x_alpha_t - (1 / num_steps) * self.forward(x_alpha_t, t_tensor * t_array)
499
- # x_alpha_t is x0 now
500
- # iterate to get x1
501
- for t in np.linspace(0, 1, num_steps + 1)[:-1]:
502
- t_tensor = torch.tensor(t)
503
- # print(unet.device, noise.device, t_tensor.device, t_array.device)
504
- x_alpha_t = x_alpha_t + 1 / num_steps * self.forward(x_alpha_t, t_tensor * t_array)
505
-
506
- return x_alpha_t
507
-
508
-
509
-
510
- def on_validation_epoch_end(self) -> None:
511
- if self.current_epoch % self.log_samples_every == 0:
512
- samples = self.sample_fn(16, 100).detach().cpu().numpy()
513
- if self.ckpt is not None:
514
- os.makedirs(os.path.join(self.ckpt, 'samples', str(self.current_epoch)), exist_ok=True)
515
- fig, axs = plt.subplots(4, 4, figsize=(16, 16))
516
- for i in range(4):
517
- for j in range(4):
518
- axs[i, j].plot(samples[i*4+j].squeeze())
519
- pd.DataFrame(samples[i*4+j].squeeze(), columns=['normalized_pitch']).to_csv(os.path.join(self.ckpt, 'samples', str(self.current_epoch), f'sample_{i*4+j}.csv'))
520
- if self.logger:
521
- wandb.log({"samples": [wandb.Image(fig, caption="Samples")]})
522
- else:
523
- fig.savefig(os.path.join(self.ckpt, 'samples', str(self.current_epoch), 'samples.png'))
524
- plt.close(fig)
525
-
526
-
527
- @gin.configurable
528
- class UNetAudio(UNetBase):
529
- def __init__(self,
530
- inp_dim,
531
- time_dim,
532
- features,
533
- strides,
534
- kernel_size,
535
- seq_len,
536
- project_dim=None,
537
- dropout=0.0,
538
- nonlinearity=None,
539
- norm=True,
540
- num_convs=2,
541
- num_attns=2,
542
- num_heads=8,
543
- ckpt=None,
544
- qt = None,
545
- log_samples_every = 10,
546
- log_wandb_samples_every = 50,
547
- sr=16000,
548
- loss_w_padding=False,
549
- log_grad_norms_every=10
550
- ):
551
- super(UNetAudio, self).__init__()
552
- self.inp_dim = inp_dim
553
- self.time_dim = time_dim
554
- self.features = features
555
- self.strides = strides
556
- self.kernel_size = kernel_size
557
- self.seq_len = seq_len
558
- self.log_samples_every = log_samples_every
559
- self.log_wandb_samples_every = log_wandb_samples_every
560
- self.ckpt = ckpt
561
- self.qt = qt
562
- self.sr = sr
563
- self.strides_prod = np.prod(strides)
564
- self.loss_w_padding = loss_w_padding
565
- self.log_grad_norms_every = log_grad_norms_every
566
-
567
- if project_dim is None:
568
- project_dim = features[0]
569
- self.initial_projection = nn.Conv1d(inp_dim, project_dim, kernel_size=1)
570
- self.positional_encoding = PositionalEncoding(time_dim)
571
-
572
- features = [project_dim] + features
573
- strides = [None] + strides
574
-
575
- self.downsample_layers = nn.ModuleList([
576
- ResNetBlock(features[ind-1] + time_dim,
577
- features[ind],
578
- kernel_size=kernel_size,
579
- stride=strides[ind],
580
- dropout=dropout,
581
- nonlinearity=nonlinearity,
582
- norm=norm,
583
- num_convs=num_convs,
584
- ) for ind in range(1, len(features))
585
- ])
586
-
587
- self.attention_layers = AttentionLayers(
588
- dim = features[-1],
589
- heads = num_heads,
590
- depth = num_attns,
591
- )
592
-
593
- self.upsample_layers = nn.ModuleList([
594
- ResNetBlock(features[ind] * 2 + time_dim, # input size defined by features + skip dimension + time dimension
595
- features[ind-1],
596
- kernel_size=kernel_size,
597
- stride=strides[ind],
598
- dropout=dropout,
599
- nonlinearity=nonlinearity,
600
- norm=norm,
601
- num_convs=num_convs,
602
- up=True
603
- ) for ind in range(len(features) - 1, 0, -1)
604
- ])
605
- self.final_projection = nn.Conv1d(2*project_dim, inp_dim, kernel_size=1)
606
- self.losses = []
607
-
608
- def forward(self, x, time):
609
- # INITIAL PROJECTION
610
- x = self.initial_projection(x)
611
-
612
- # TIME CONDITIONING
613
- time = self.positional_encoding(time)
614
-
615
- def _concat_time(x, time):
616
- time = time.unsqueeze(2).expand(-1, -1, x.shape[-1])
617
- x = torch.cat([x, time], -2)
618
- return x
619
-
620
- skips = []
621
-
622
- # DOWNSAMPLING
623
- for ind, downsample_layer in enumerate(self.downsample_layers):
624
- # print(f'Down sample layer {ind}')
625
- skips.append(x)
626
- x = _concat_time(x, time)
627
- x = downsample_layer(x)
628
- skips.append(x)
629
- # BOTTLENECK ATTENTION
630
- x = torch.permute(x, (0, 2, 1))
631
- x = self.attention_layers(x)
632
- x = torch.permute(x, (0, 2, 1))
633
-
634
- # pdb.set_trace()
635
- # UPSAMPLING
636
- for ind, upsample_layer in enumerate(self.upsample_layers):
637
- # print(f'Up sample layer {ind}')
638
- x = _concat_time(x, time)
639
- x = torch.cat([x, skips.pop(-1)], 1)
640
- x = upsample_layer(x)
641
- x = torch.cat([x, skips.pop(-1)], 1)
642
-
643
- # FINAL PROJECTION
644
- x = self.final_projection(x)
645
- return x
646
-
647
- def pad_to(self, x, strides):
648
- # modified from: https://stackoverflow.com/questions/66028743/how-to-handle-odd-resolutions-in-unet-architecture-pytorch
649
- l = x.shape[-1]
650
-
651
- if l % strides > 0:
652
- new_l = l + strides - l % strides
653
- else:
654
- new_l = l
655
-
656
- ll, ul = int((new_l-l) / 2), int(new_l-l) - int((new_l-l) / 2)
657
- pads = (ll, ul)
658
-
659
- out = F.pad(x, pads, "reflect").to(x)
660
-
661
- return out, pads
662
-
663
- def unpad(self, x, pad):
664
- # modified from: https://stackoverflow.com/questions/66028743/how-to-handle-odd-resolutions-in-unet-architecture-pytorch
665
- if pad[0]+pad[1] > 0:
666
- x = x[:,:,pad[0]:-pad[1]]
667
- return x
668
-
669
- def loss(self, x):
670
- padded_x, padding = self.pad_to(x, self.strides_prod)
671
- t = torch.rand((padded_x.shape[0],)).to(padded_x)
672
- noise = torch.normal(0, 1, padded_x.shape).to(padded_x)
673
- # print(t.device, noise.device, x.device)
674
- x_t = t[:, None, None] * padded_x + (1 - t[:, None, None]) * noise
675
- # print(t.device, noise.device, x_t.device, x.device)
676
- padded_y = self.forward(x_t, t)
677
- unpadded_y = self.unpad(padded_y, padding)
678
-
679
- if self.loss_w_padding:
680
- target = padded_x - noise
681
- return torch.mean((padded_y - target) ** 2)
682
- else:
683
- target = x - self.unpad(noise, padding) # x1 - x0
684
- return torch.mean((unpadded_y - target) ** 2)
685
-
686
- def training_step(self, batch, batch_idx):
687
- # print('\n', batch_idx, batch.shape)
688
- x = batch
689
- loss = self.loss(x)
690
- self.log('train_loss', loss)
691
- return loss
692
-
693
- def validation_step(self, batch, batch_idx):
694
- x = batch
695
- loss = self.loss(x)
696
- self.log('val_loss', loss)
697
- return loss
698
-
699
- def sample_fn(self, batch_size: int, num_steps: int, prime=None):
700
- if prime is not None:
701
- prime = prime.to(self.device)
702
- # CREATE INITIAL NOISE
703
- noise = torch.normal(mean=0, std=1, size=(batch_size, self.inp_dim, self.seq_len)).to(self.device)
704
- padded_noise, padding = self.pad_to(noise, self.strides_prod)
705
- x_alpha_t = padded_noise.clone()
706
- t_array = torch.ones((batch_size,)).to(self.device)
707
- with torch.no_grad():
708
- # SAMPLE FROM MODEL
709
- for t in np.linspace(0, 1, num_steps + 1)[:-1]:
710
- t_tensor = torch.tensor(t)
711
- alpha_t = t_tensor * t_array
712
- alpha_t = alpha_t.unsqueeze(1).unsqueeze(2).to(self.device)
713
- if prime is not None:
714
- x_alpha_t[:, :, :prime.shape[-1]] = ((1 - alpha_t) * noise[:, :, :prime.shape[-1]]) + (alpha_t * prime) # fill in the prime in the beginning of each x_t
715
- diff = self.forward(x_alpha_t, t_tensor * t_array)
716
- x_alpha_t = x_alpha_t + 1 / num_steps * diff
717
-
718
- padded_y = x_alpha_t
719
- unpadded_y = self.unpad(padded_y, padding)
720
-
721
- return unpadded_y
722
-
723
- def on_validation_epoch_end(self) -> None:
724
- if self.current_epoch % self.log_samples_every == 0:
725
- if self.ckpt is not None:
726
- os.makedirs(os.path.join(self.ckpt, 'samples', str(self.current_epoch)), exist_ok=True)
727
- samples = self.sample_fn(16, 100)
728
- audio = p2a.normalized_mels_to_audio(samples, qt=self.qt)
729
- beep = torch.sin(2 * torch.pi * 220 * torch.arange(0, 0.1 * self.sr) / self.sr).to(audio)
730
- concat_audios = []
731
- for sample in audio:
732
- concat_audios.append(torch.cat([sample, beep]))
733
- concat_audio = torch.cat(concat_audios, dim=-1).reshape(1, -1).to('cpu')
734
- output_file = os.path.join(self.ckpt, 'samples', f'samples_{self.current_epoch}.wav')
735
- torchaudio.save(output_file, concat_audio, self.sr)
736
- if self.current_epoch % self.log_wandb_samples_every == 0:
737
- if self.logger:
738
- wandb.log({
739
- "samples": [wandb.Audio(output_file, self.sr, caption="Samples")]})
740
-
741
- def on_before_optimizer_step(self, optimizer, *_):
742
- def calculate_grad_norm(module_list, norm_type=2):
743
- total_norm = 0
744
- if isinstance(module_list, nn.Module):
745
- module_list = [module_list]
746
- for module in module_list:
747
- for name, param in module.named_parameters():
748
- if param.requires_grad:
749
- param_norm = torch.norm(param.grad.detach(), p=norm_type)
750
- total_norm += param_norm**2
751
- # pdb.set_trace()
752
- total_norm = torch.sqrt(total_norm)
753
- return total_norm
754
-
755
- if self.log_grad_norms_every is not None and self.global_step % self.log_grad_norms_every == 0:
756
- self.log('Grad Norm/Downsample Layers', calculate_grad_norm(self.downsample_layers))
757
- self.log('Grad Norm/Attention Layers', calculate_grad_norm(self.attention_layers))
758
- self.log('Grad Norm/Upsample Layers', calculate_grad_norm(self.upsample_layers))
759
- # def configure_optimizers(self):
760
- # return optim.Adam(self.parameters(), lr=1e-4)
761
-
762
- @gin.configurable
763
- class UNetPitchConditioned(UNetBase):
764
- def __init__(self,
765
- inp_dim,
766
- time_dim,
767
- f0_dim,
768
- features,
769
- strides,
770
- kernel_size,
771
- audio_seq_len,
772
- project_dim=None,
773
- dropout=0.0,
774
- nonlinearity=None,
775
- norm=True,
776
- num_convs=2,
777
- num_attns=2,
778
- num_heads=8,
779
- log_samples_every=10,
780
- log_wandb_samples_every=10,
781
- ckpt=None,
782
- val_data=None,
783
- qt=None,
784
- singer_conditioning=False,
785
- singer_dim=128,
786
- singer_vocab=56,
787
- sr = 44100,
788
- cfg = False,
789
- f0_mask = 0,
790
- cond_drop_prob = 0.0,
791
- groups = None,
792
- nfft = None,
793
- loss_w_padding = False,
794
- log_grad_norms_every=10
795
- ):
796
- super(UNetPitchConditioned, self).__init__()
797
- self.inp_dim = inp_dim
798
- self.time_dim = time_dim
799
- self.features = features
800
- self.strides = strides
801
- self.kernel_size = kernel_size
802
- self.seq_len = audio_seq_len
803
- self.log_samples_every = log_samples_every
804
- self.log_wandb_samples_every = log_wandb_samples_every
805
- self.ckpt = ckpt
806
- self.qt = qt
807
- self.singer_conditioning = singer_conditioning
808
- self.sr = sr # used for logging audio to wandb
809
- self.cond_drop_prob = cond_drop_prob
810
- self.f0_masked_token = f0_mask
811
- self.cfg = cfg
812
- self.strides_prod = np.prod(strides)
813
- self.loss_w_padding = loss_w_padding
814
- self.log_grad_norms_every = log_grad_norms_every
815
-
816
- conditioning_dim = time_dim
817
- if singer_conditioning:
818
- conditioning_dim += singer_dim
819
-
820
- if project_dim is None:
821
- project_dim = features[0]
822
- self.initial_projection = nn.Conv1d(inp_dim, project_dim, kernel_size=1)
823
- self.time_positional_encoding = PositionalEncoding(time_dim)
824
- self.f0_positional_encoding = PositionalEncoding(f0_dim)
825
-
826
- if singer_conditioning:
827
- self.singer_embedding = nn.Embedding(singer_vocab + 1*self.cfg, singer_dim) # if cfg, add 1 to the singer vocabulary
828
- self.singer_masked_token = singer_vocab
829
- else:
830
- self.singer_embedding = None
831
-
832
- features = [project_dim] + features
833
- f0_features = features.copy()
834
- f0_features[0] = f0_dim # first layer should be the f0 dimension
835
- strides = [None] + strides
836
-
837
- self.downsample_layers = nn.ModuleList([
838
- ResNetBlock(features[ind-1] + conditioning_dim,
839
- features[ind],
840
- kernel_size=kernel_size,
841
- stride=strides[ind],
842
- dropout=dropout,
843
- nonlinearity=nonlinearity,
844
- norm=norm,
845
- num_convs=num_convs,
846
- ) for ind in range(1, len(features))
847
- ])
848
-
849
- self.f0_conv_layers = nn.ModuleList([
850
- nn.Conv1d(
851
- f0_dim,
852
- f0_dim,
853
- kernel_size=2 * strides[ind],
854
- stride=strides[ind],
855
- padding=strides[ind]//2,
856
- ) for ind in range(1, len(features))
857
- ])
858
-
859
- self.attention_layers = AttentionLayers(
860
- dim = features[-1],
861
- heads = num_heads,
862
- depth = num_attns,
863
- )
864
-
865
- self.upsample_layers = nn.ModuleList([
866
-             ResNetBlock((features[ind] * 2) + (conditioning_dim) + f0_dim, # input size defined by features + skip dimension + conditioning (time/singer) dimension + f0 dimension
867
- features[ind-1],
868
- kernel_size=kernel_size,
869
- stride=strides[ind],
870
- dropout=dropout,
871
- nonlinearity=nonlinearity,
872
- norm=norm,
873
- num_convs=num_convs,
874
- up=True
875
- ) for ind in range(len(features) - 1, 0, -1)
876
- ])
877
- self.final_projection = nn.Conv1d(2*project_dim + f0_dim, inp_dim, kernel_size=1)
878
-         # save 16 f0 values from the validation set to sample on
879
- if val_data is not None:
880
- val_ids = np.random.choice(len(val_data), 16)
881
- val_samples = [val_data[i] for i in val_ids]
882
- self.val_f0 = torch.stack([v[1] for v in val_samples], 0).to(self.device)
883
- if self.singer_conditioning:
884
- self.val_singer = torch.tensor([v[2] for v in val_samples]).long().to(self.device)
885
- else:
886
- self.val_singer = None
887
- val_audio = torch.stack([v[0] for v in val_samples], 0).to(self.device)
888
- if self.ckpt is not None:
889
- # log the f0 and audio to wandb
890
- os.makedirs(os.path.join(self.ckpt, 'samples'), exist_ok=True)
891
- concat_audios = []
892
- beep = torch.sin(2 * torch.pi * 220 * torch.arange(0, 0.1 * self.sr) / self.sr).to(val_audio)
893
- recon_audios = p2a.normalized_mels_to_audio(val_audio, qt=self.qt)
894
- fig, axs = plt.subplots(4, 4, figsize=(16, 16))
895
- for i in range(4):
896
- for j in range(4):
897
- axs[i, j].plot(self.val_f0[i*4+j].squeeze())
898
- if self.singer_conditioning:
899
- axs[i, j].set_title(f'Singer {self.val_singer[i*4+j].item()}')
900
- concat_audios.append(torch.cat((recon_audios[i*4+j].squeeze(), beep)))
901
- concat_audios = torch.cat(concat_audios, dim=-1).reshape(1, -1).to('cpu')
902
- output_file = os.path.join(self.ckpt, 'samples', f'gt_samples.wav')
903
- torchaudio.save(output_file, concat_audios, self.sr)
904
-
905
- try:
906
- wandb.log({"sample f0 input": [wandb.Image(fig, caption="f0 conditioning on samples")]})
907
- wandb.log({
908
- "sample audio ground truth": [wandb.Audio(output_file, self.sr, caption="Samples")]})
909
- except:
910
- pass
911
-
912
- fig.savefig(os.path.join(self.ckpt, 'samples', 'f0_inputs.png'))
913
-
914
- def pad_to(self, x, strides):
915
- # modified from: https://stackoverflow.com/questions/66028743/how-to-handle-odd-resolutions-in-unet-architecture-pytorch
916
- l = x.shape[-1]
917
-
918
- if l % strides > 0:
919
- new_l = l + strides - l % strides
920
- else:
921
- new_l = l
922
-
923
- ll, ul = int((new_l-l) / 2), int(new_l-l) - int((new_l-l) / 2)
924
- pads = (ll, ul)
925
-
926
- out = F.pad(x, pads, "reflect").to(x)
927
-
928
- return out, pads
929
-
930
- def unpad(self, x, pad):
931
- # modified from: https://stackoverflow.com/questions/66028743/how-to-handle-odd-resolutions-in-unet-architecture-pytorch
932
- if pad[0]+pad[1] > 0:
933
- x = x[:,:,pad[0]:-pad[1]]
934
- return x
935
-
936
- def forward(self, x, time, f0, singer, drop_tokens=True, drop_all=False):
937
- # INITIAL PROJECTION
938
- x = self.initial_projection(x)
939
-
940
- bs = x.shape[0]
941
- if self.cfg:
942
- # pdb.set_trace()
943
- if drop_all:
944
- prob_keep_mask_pitch = torch.zeros((bs)).unsqueeze(1).repeat(1, f0.shape[1]).to(self.device).bool()
945
- prob_keep_mask_singer = torch.zeros((bs)).to(self.device).bool()
946
- elif drop_tokens:
947
- prob_keep_mask_pitch = prob_mask_like((bs), 1. - self.cond_drop_prob, device = self.device).unsqueeze(1).repeat(1, f0.shape[1])
948
- prob_keep_mask_singer = prob_mask_like((bs), 1. - self.cond_drop_prob, device = self.device)
949
- else:
950
- prob_keep_mask_pitch = torch.ones((bs)).unsqueeze(1).repeat(1, f0.shape[1]).to(self.device).bool()
951
- prob_keep_mask_singer = torch.ones((bs)).to(self.device).bool()
952
- f0 = torch.where(prob_keep_mask_pitch, f0, torch.empty((f0.shape[0], f0.shape[1])).fill_(self.f0_masked_token).to(self.device).long())
953
- if self.singer_conditioning:
954
- singer = torch.where(prob_keep_mask_singer, singer, torch.empty((bs)).fill_(self.singer_masked_token).to(self.device).long())
955
-
956
- # TIME and F0 CONDITIONING
957
- conditions = [self.time_positional_encoding(time)]
958
- if self.singer_conditioning:
959
- conditions.append(self.singer_embedding(singer))
960
- f0 = self.f0_positional_encoding(f0)
961
-
962
- def _concat_condition(x, condition):
963
- condition = condition.unsqueeze(2).expand(-1, -1, x.shape[-1])
964
- x = torch.cat([x, condition], -2)
965
- return x
966
-
967
- skips = []
968
-
969
- # DOWNSAMPLING
970
- # pdb.set_trace()
971
- for ind, downsample_layer in enumerate(self.downsample_layers):
972
- # print(f'Down sample layer {ind}')
973
- # pdb.set_trace()
974
- skips.append(torch.cat([x, f0], -2))
975
- for cond in conditions:
976
- x = _concat_condition(x, cond)
977
- # print(x.shape, time.shape, f0.shape, skips[-1].shape)
978
- x = downsample_layer(x)
979
- f0 = self.f0_conv_layers[ind](f0)
980
- skips.append(torch.cat([x, f0], -2))
981
- # BOTTLENECK ATTENTION
982
- x = torch.permute(x, (0, 2, 1))
983
- x = self.attention_layers(x)
984
- x = torch.permute(x, (0, 2, 1))
985
- # print(x.shape, time.shape, f0.shape, skips[-1].shape)
986
- # pdb.set_trace()
987
- # UPSAMPLING
988
- for ind, upsample_layer in enumerate(self.upsample_layers):
989
- # print(f'Up sample layer {ind}')
990
- for cond in conditions:
991
- x = _concat_condition(x, cond)
992
- x = torch.cat([x, skips.pop(-1)], 1)
993
- # print(x.shape, time.shape, f0.shape)
994
- x = upsample_layer(x)
995
- x = torch.cat([x, skips.pop(-1)], 1)
996
-
997
- # FINAL PROJECTION
998
- x = self.final_projection(x)
999
- return x
1000
-
1001
- def loss(self, x, f0, singer, drop_tokens):
1002
- # pdb.set_trace()
1003
- padded_x, padding = self.pad_to(x, self.strides_prod)
1004
- padded_f0, _ = self.pad_to(f0, self.strides_prod)
1005
- t = torch.rand((padded_x.shape[0],)).to(padded_x)
1006
- noise = torch.normal(0, 1, padded_x.shape).to(padded_x)
1007
- # print(t.device, noise.device, x.device)
1008
- x_t = t[:, None, None] * padded_x + (1 - t[:, None, None]) * noise
1009
- # print(t.device, noise.device, x_t.device, x.device)
1010
- padded_y = self.forward(x_t, t, padded_f0, singer, drop_tokens)
1011
- unpadded_y = self.unpad(padded_y, padding)
1012
-
1013
- if self.loss_w_padding:
1014
- target = padded_x - noise
1015
- return torch.mean((padded_y - target) ** 2)
1016
- else:
1017
- target = x - self.unpad(noise, padding) # x1 - x0
1018
- return torch.mean((unpadded_y - target) ** 2)
1019
-
1020
- def training_step(self, batch, batch_idx):
1021
- # print('\n', batch_idx, batch.shape)
1022
- x, f0, singer = batch
1023
- x = x.to(self.device)
1024
- f0 = f0.to(self.device)
1025
- singer = singer.reshape(-1).long().to(self.device) if self.singer_conditioning else None
1026
- loss = self.loss(x, f0, singer, drop_tokens=True)
1027
- self.log('train_loss', loss, batch_size=x.shape[0])
1028
- return loss
1029
-
1030
- def validation_step(self, batch, batch_idx):
1031
- # pdb.set_trace()
1032
- x, f0, singer = batch
1033
- x = x.to(self.device)
1034
- f0 = f0.to(self.device)
1035
- singer = singer.reshape(-1).long().to(self.device) if self.singer_conditioning else None
1036
- loss = self.loss(x, f0, singer, drop_tokens=False)
1037
- self.log('val_loss', loss, batch_size=x.shape[0])
1038
- return loss
1039
-
1040
- def sample_fn(self, f0, singer, batch_size: int, num_steps: int):
1041
- # CREATE INITIAL NOISE
1042
- noise = torch.normal(mean=0, std=1, size=(batch_size, self.inp_dim, self.seq_len)).to(self.device)
1043
- padded_noise, padding = self.pad_to(noise, self.strides_prod)
1044
- t_array = torch.ones((batch_size,)).to(self.device)
1045
- f0 = f0.to(self.device)
1046
- padded_f0, _ = self.pad_to(f0, self.strides_prod)
1047
- singer = singer.to(self.device)
1048
- with torch.no_grad():
1049
- # SAMPLE FROM MODEL
1050
- for t in np.linspace(0, 1, num_steps + 1)[:-1]:
1051
- t_tensor = torch.tensor(t)
1052
- padded_noise = padded_noise + 1 / num_steps * self.forward(padded_noise, t_tensor * t_array, padded_f0, singer, drop_tokens=False)
1053
- noise = self.unpad(padded_noise, padding)
1054
- return noise
1055
-
1056
- def sample_cfg(self, batch_size: int, num_steps: int, f0=None, singer=[4, 25, 45, 32], strength=1):
1057
- # CREATE INITIAL NOISE
1058
- noise = torch.normal(mean=0, std=1, size=(batch_size, self.inp_dim, self.seq_len)).to(self.device)
1059
- padded_noise, padding = self.pad_to(noise, self.strides_prod)
1060
- t_array = torch.ones((batch_size,)).to(self.device)
1061
- if f0 is None:
1062
- val_idx = np.random.choice(len(self.val_dataloader), batch_size)
1063
- val_samples = [self.val_dataloader[i][1] for i in val_idx]
1064
- f0 = torch.stack([sample for sample in val_samples]).to(self.device)
1065
- else:
1066
- assert len(f0) == batch_size
1067
- f0 = f0.to(self.device)
1068
- singer = singer.to(self.device)
1069
- # f0 = torch.tensor(f0).to(self.device)
1070
- # singer = torch.Tensor(np.choice(singer, batch_size, replace=True)).to(self.device)
1071
- padded_f0, _ = self.pad_to(f0, self.strides_prod)
1072
- with torch.no_grad():
1073
- # SAMPLE FROM MODEL
1074
- for t in np.linspace(0, 1, num_steps + 1)[:-1]:
1075
- t_tensor = torch.tensor(t)
1076
- unconditioned_logits = self.forward(padded_noise, t_tensor * t_array, padded_f0, singer, drop_tokens=False, drop_all=True)
1077
- conditioned_logits = self.forward(padded_noise, t_tensor * t_array, padded_f0, singer, drop_tokens=False, drop_all=False)
1078
- total_logits = strength * conditioned_logits + (1 - strength) * unconditioned_logits
1079
- padded_noise = padded_noise + 1 / num_steps * total_logits
1080
-
1081
- noise = self.unpad(padded_noise, padding)
1082
- return noise, f0, singer
1083
-
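sample_cfg implements classifier-free guidance: each Euler step evaluates the model once with all conditioning masked (drop_all=True) and once with it kept, then blends the two velocity estimates. The blend strength*cond + (1 - strength)*uncond is algebraically the usual uncond + strength*(cond - uncond) guidance rule, as this small check shows (the tensors below are hypothetical stand-ins):

import torch

strength = 2.0
uncond = torch.randn(4, 512, 192)   # velocity with conditioning dropped
cond = torch.randn(4, 512, 192)     # velocity with conditioning kept

guided_a = strength * cond + (1 - strength) * uncond
guided_b = uncond + strength * (cond - uncond)
print(torch.allclose(guided_a, guided_b))   # True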
1084
- def on_validation_epoch_end(self) -> None:
1085
- with torch.no_grad():
1086
- # pdb.set_trace()
1087
- if self.current_epoch % self.log_samples_every == 0:
1088
- samples = self.sample_fn(self.val_f0, self.val_singer, 16, 100)
1089
- if self.ckpt is not None:
1090
- audio = p2a.normalized_mels_to_audio(samples, qt=self.qt)
1091
- beep = torch.sin(2 * torch.pi * 220 * torch.arange(0, 0.1 * self.sr) / self.sr).to(audio)
1092
- concat_audio = []
1093
- for sample in audio:
1094
- concat_audio.append(torch.cat([sample, beep]))
1095
- concat_audio = torch.cat(concat_audio, dim=-1).reshape(1, -1).to('cpu')
1096
- output_file = os.path.join(self.ckpt, 'samples', f'samples_{self.current_epoch}.wav')
1097
- torchaudio.save(output_file, concat_audio, self.sr)
1098
- if self.current_epoch % self.log_wandb_samples_every == 0:
1099
- if self.logger:
1100
- wandb.log({
1101
- "samples": [wandb.Audio(output_file, self.sr, caption="Samples")]},
1102
- step = self.global_step)
1103
- def on_before_optimizer_step(self, optimizer, *_):
1104
- def calculate_grad_norm(module_list, norm_type=2):
1105
- total_norm = 0
1106
- if isinstance(module_list, nn.Module):
1107
- module_list = [module_list]
1108
- for module in module_list:
1109
- for name, param in module.named_parameters():
1110
- if param.requires_grad:
1111
- param_norm = torch.norm(param.grad.detach(), p=norm_type)
1112
- total_norm += param_norm**2
1113
- # pdb.set_trace()
1114
- total_norm = torch.sqrt(total_norm)
1115
- return total_norm
1116
-
1117
- if self.log_grad_norms_every is not None and self.global_step % self.log_grad_norms_every == 0:
1118
- self.log('Grad Norm/Downsample Layers', calculate_grad_norm(self.downsample_layers))
1119
- self.log('Grad Norm/Attention Layers', calculate_grad_norm(self.attention_layers))
1120
- self.log('Grad Norm/Upsample Layers', calculate_grad_norm(self.upsample_layers))
1121
-
1122
- # @gin.configurable
1123
- # def configure_optimizers(self, optimizer_cls: Callable[[], torch.optim.Optimizer],
1124
- # scheduler_cls: Callable[[],
1125
- # torch.optim.lr_scheduler._LRScheduler]):
1126
- # # pdb.set_trace()
1127
- # optimizer = optimizer_cls(self.parameters())
1128
- # scheduler = scheduler_cls(optimizer)
1129
-
1130
- # return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]
src/pitch_to_audio_utils.py DELETED
@@ -1,121 +0,0 @@
1
- import math
2
- import librosa as li
3
- import torch
4
- from tqdm import tqdm
5
- import numpy as np
6
- import gin
7
- import logging
8
-
9
- import pdb
10
-
11
- @gin.configurable
12
- def torch_stft(x, nfft):
13
- window = torch.hann_window(nfft).to(x)
14
- x = torch.stft(
15
- x,
16
- n_fft=nfft,
17
- hop_length=nfft // 4,
18
- win_length=nfft,
19
- window=window,
20
- center=True,
21
- return_complex=True,
22
- )
23
- x = 2 * x / torch.mean(window)
24
- return x
25
-
26
- @gin.configurable
27
- def torch_istft(x, nfft):
28
- # pdb.set_trace()
29
- window = torch.hann_window(nfft).to(x.device)
30
- x = x / 2 * torch.mean(window)
31
- return torch.istft(
32
- x,
33
- n_fft=nfft,
34
- hop_length=nfft // 4,
35
- win_length=nfft,
36
- window=window,
37
- center=True,
38
- )
39
-
40
- @gin.configurable
41
- def to_mels(stft, nfft, num_mels, sr, eps=1e-2):
42
- mels = li.filters.mel(
43
- sr=sr,
44
- n_fft=nfft,
45
- n_mels=num_mels,
46
- fmin=40,
47
- )
48
- # pdb.set_trace()
49
- mels = torch.from_numpy(mels).to(stft)
50
- mel_stft = torch.einsum("mf,bft->bmt", mels, stft)
51
- mel_stft = torch.log(mel_stft + eps)
52
- return mel_stft
53
-
54
- @gin.configurable
55
- def from_mels(mel_stft, nfft, num_mels, sr, eps=1e-2):
56
- mels = li.filters.mel(
57
- sr=sr,
58
- n_fft=nfft,
59
- n_mels=num_mels,
60
- fmin=40,
61
- )
62
- mels = torch.from_numpy(mels).to(mel_stft)
63
- mels = torch.pinverse(mels)
64
- mel_stft = torch.exp(mel_stft) - eps
65
- stft = torch.einsum("fm,bmt->bft", mels, mel_stft)
66
- return stft
67
-
68
- @gin.configurable
69
- def torch_gl(stft, nfft, sr, n_iter):
70
-
71
- def _gl_iter(phase, xs, stft):
72
- del xs
73
- # pdb.set_trace()
74
- c_stft = stft * torch.exp(1j * phase)
75
- rec = torch_istft(c_stft, nfft)
76
- r_stft = torch_stft(rec, nfft)
77
- phase = torch.angle(r_stft)
78
- return phase, None
79
-
80
- phase = torch.rand_like(stft) * 2 * torch.pi
81
-
82
- for _ in tqdm(range(n_iter)):
83
- phase, _ = _gl_iter(phase, None, stft)
84
-
85
- c_stft = stft * torch.exp(1j * phase)
86
- audio = torch_istft(c_stft, nfft)
87
-
88
- return audio
89
-
90
- @gin.configurable
91
- def normalize(x, qt=None):
92
- x_flat = x.reshape(-1, 1)
93
- if qt is None:
94
- logging.warning('No quantile transformer found, returning input')
95
- return x
96
- return torch.Tensor(qt.transform(x_flat).reshape(x.shape))
97
-
98
- @gin.configurable
99
- def unnormalize(x, qt=None):
100
- x_flat = x.reshape(-1, 1)
101
- if qt is None:
102
- logging.warning('No quantile transformer found, returning input')
103
- return x
104
- if isinstance(x_flat, torch.Tensor):
105
- x_flat = x_flat.detach().cpu().numpy()
106
- return torch.Tensor(qt.inverse_transform(x_flat).reshape(x.shape))
107
-
108
- @gin.configurable
109
- def audio_to_normalized_mels(x, nfft, num_mels, sr, qt):
110
- # pdb.set_trace()
111
- stfts = torch_stft(x, nfft=nfft).abs()[..., :-1]
112
- mel_stfts = to_mels(stfts, nfft, num_mels, sr)
113
- return normalize(mel_stfts, qt).to(x)
114
-
115
- @gin.configurable
116
- def normalized_mels_to_audio(x, nfft, num_mels, sr, qt, n_iter=20):
117
- x = unnormalize(x, qt).to(x)
118
- x = from_mels(x, nfft, num_mels, sr)
119
- x = torch.clamp(x, 0, nfft)
120
- x = torch_gl(x, nfft, sr, n_iter=n_iter)
121
- return x
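The functions above form a round trip: audio_to_normalized_mels produces quantile-normalized log-mel spectrograms for training, and normalized_mels_to_audio inverts them through a pseudo-inverse mel filterbank plus Griffin-Lim. A hedged usage sketch (the nfft/num_mels/sr values are illustrative and qt is left as None here; in the repo these come from the gin config and a fitted QuantileTransformer):

import torch
import src.pitch_to_audio_utils as p2a

nfft, num_mels, sr = 2048, 128, 16000    # illustrative, normally bound via gin
qt = None                                # a fitted sklearn QuantileTransformer in practice

audio = torch.randn(1, 4 * sr)           # stand-in for a 4-second clip
mels = p2a.audio_to_normalized_mels(audio, nfft, num_mels, sr, qt)
recon = p2a.normalized_mels_to_audio(mels, nfft, num_mels, sr, qt, n_iter=20)
print(mels.shape, recon.shape)
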
src/preprocess_utils.py DELETED
@@ -1,127 +0,0 @@
1
- import subprocess
2
- import numpy as np
3
- import pandas as pd
4
- from typing import Iterable, Tuple, Callable
5
- import multiprocessing
6
- import functools
7
- from itertools import repeat
8
- from protobuf.data_example import AudioExample, DTYPE_TO_PRECISION
9
- import librosa
10
- import pdb
11
- # from memory_profiler import profile
12
-
13
- # @profile
14
- def load_chunk(
15
- row: pd.Series,
16
- n_signal_audio: int,
17
- n_signal_pitch: int,
18
- sr_audio: int,
19
- sr_pitch: int,
20
- error_path: str = None,
21
- ) -> Iterable[np.ndarray]:
22
- audio_path = row['audio_path']
23
- csv_path = row['pitch_path']
24
- # print (audio_path, csv_path)
25
- # pdb.set_trace()
26
- try:
27
- chunk_csv = pd.read_csv(csv_path, chunksize=n_signal_pitch)
28
- except:
29
- if error_path is not None:
30
- with open(error_path, 'a') as f:
31
- f.write(f'Error reading {csv_path}\n')
32
- return
33
- chunk_iter = iter(chunk_csv)
34
-
35
- chunk_pitch = next(chunk_iter)
36
- f0 = chunk_pitch['filtered_f0'].fillna(0).to_numpy()
37
-
38
- # print('Number of chunks: ', pd.read_csv(csv_path).shape[0]//n_signal_pitch, '\n')
39
- while len(f0) == n_signal_pitch:
40
- start_time = chunk_pitch['time'].values[0]
41
- # print(start_time, chunk_pitch['time'].values[-1] - ((n_signal_pitch - 1)/sr_pitch))
42
- assert abs(start_time - (chunk_pitch['time'].values[-1] - ((n_signal_pitch - 1)/sr_pitch))) < 1e-6 # check that no time stamps were skipped
43
- chunk_audio = librosa.load(audio_path, sr=sr_audio, offset=start_time, duration=n_signal_audio/sr_audio, dtype=np.float32)[0]
44
- assert chunk_audio.shape[0] == n_signal_audio
45
- # and len(f0) == n_signal_pitch:
46
- # chunk_audio /= 2**15
47
- # pdb.set_trace()
48
- yield chunk_audio, f0, row, start_time
49
- try:
50
- chunk_pitch = next(chunk_iter)
51
- f0 = chunk_pitch['filtered_f0'].fillna(0).to_numpy()
52
- except StopIteration:
53
- return
54
-
55
-
56
- def flatmap(
57
- pool: multiprocessing.Pool,
58
- func: Callable,
59
- iterable: Iterable,
60
- queue_size: int,
61
- chunksize=None,
62
- ):
63
- queue = multiprocessing.Manager().Queue(maxsize=queue_size)
64
- pool.map_async(
65
- functools.partial(flat_mappper, func),
66
- zip(iterable, repeat(queue)),
67
- chunksize,
68
- lambda _: queue.put(None),
69
- lambda *e: print(e),
70
- )
71
-
72
- item = queue.get()
73
- while item is not None:
74
- # print(item)
75
- yield item
76
- item = queue.get()
77
-
78
- def flat_mappper(func, arg):
79
- data, queue = arg
80
- for item in func(data):
81
- queue.put(item)
82
-
83
- def batch(iterator: Iterable, batch_size: int):
84
- batch = []
85
- for elm in iterator:
86
- batch.append(elm)
87
- if len(batch) == batch_size:
88
- yield batch
89
- batch = []
90
- if len(batch):
91
- yield batch
92
-
93
- def preprocess_batch(
94
- preprocessed_array,
95
- sr_audio: int,
96
- sr_pitch: int,
97
- ):
98
- # pdb.set_trace()
99
- dtype = np.float32
100
- data_examples = [AudioExample() for _ in range(len(preprocessed_array))]
101
- for ae, data in zip(data_examples, preprocessed_array):
102
- # pdb.set_trace()
103
- audio_data, csv_data, row, start_time = data
104
- buffer_audio = ae.ae.buffers['audio']
105
- buffer_audio.data = audio_data.astype(dtype).tobytes()
106
- buffer_audio.shape.extend(audio_data.shape)
107
- buffer_audio.precision = DTYPE_TO_PRECISION[dtype]
108
- buffer_audio.sampling_rate = sr_audio
109
- buffer_audio.data_path = row['audio_path']
110
- buffer_audio.start_time = start_time
111
-
112
- buffer_csv = ae.ae.buffers['pitch']
113
- buffer_csv.data = csv_data.astype(dtype).tobytes()
114
- buffer_csv.shape.extend(csv_data.shape)
115
- buffer_csv.precision = DTYPE_TO_PRECISION[dtype]
116
- buffer_csv.sampling_rate = sr_pitch
117
- buffer_csv.data_path = row['pitch_path']
118
- buffer_csv.start_time = start_time
119
-
120
- ae.ae.global_conditions.tonic = row['tonic']
121
- ae.ae.global_conditions.raga = row['raga']
122
- ae.ae.global_conditions.singer = row['singer']
123
-
124
- return data_examples
125
-
126
-
127
-
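flatmap fans a chunk-loading generator out over a multiprocessing pool and streams its items back through a bounded queue, and batch groups the stream for preprocess_batch; in the real pipeline the mapped function is load_chunk with its signal lengths and sample rates bound via functools.partial. A simplified sketch with a toy generator standing in for load_chunk:

import multiprocessing
from src.preprocess_utils import flatmap, batch

def toy_chunks(n):
    # stand-in for load_chunk: yields several items per input element
    for i in range(n):
        yield (n, i)

if __name__ == "__main__":
    with multiprocessing.Pool(4) as pool:
        stream = flatmap(pool, toy_chunks, range(1, 5), queue_size=32)
        for group in batch(stream, batch_size=8):
            print(len(group), group[:2])
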
src/process_encodec.py DELETED
@@ -1,22 +0,0 @@
1
- import gin
2
- from sklearn.preprocessing import QuantileTransformer
3
- from transformers import EncodecModel, AutoProcessor
4
- import librosa as li
5
-
6
- import pdb
7
-
8
- @gin.configurable
9
- def read_tokens(
10
- inputs,
11
- encodec_model: EncodecModel,
12
- encodec_processor: AutoProcessor,
13
- target_bandwidth: int = 3
14
- ):
15
- # pdb.set_trace()
16
- audio = inputs['audio']['data']
17
- audio = li.resample(y=audio, orig_sr= inputs['audio']['sampling_rate'], target_sr=encodec_processor.sampling_rate)
18
-
19
- encodec_inputs = encodec_processor(raw_audio=audio, sampling_rate=encodec_processor.sampling_rate, return_tensors='pt')
20
- encodec_tokens = encodec_model.encode(encodec_inputs['input_values'], bandwidth=target_bandwidth).audio_codes
21
-
22
- return encodec_tokens.detach().cpu().numpy()
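read_tokens resamples the incoming audio to the codec's rate and returns EnCodec's discrete codes at the requested bandwidth. A hedged invocation sketch using the public facebook/encodec_24khz checkpoint (the checkpoint choice and the inputs dict layout are assumptions consistent with the function above):

import numpy as np
from transformers import EncodecModel, AutoProcessor
from src.process_encodec import read_tokens

model = EncodecModel.from_pretrained("facebook/encodec_24khz")
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")

inputs = {"audio": {"data": np.random.randn(16000).astype(np.float32), "sampling_rate": 16000}}
tokens = read_tokens(inputs, encodec_model=model, encodec_processor=processor, target_bandwidth=3)
print(tokens.shape)   # (num_chunks, batch, num_codebooks, num_frames)
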
src/utils.py DELETED
@@ -1,65 +0,0 @@
1
- from pathlib import Path
2
- import os
3
- import random
4
- import torch
5
- import numpy as np
6
- import gin
7
-
8
- def search_for_run(run_path, mode="last"):
9
- if run_path is None: return None
10
- if ".ckpt" in run_path: return run_path
11
- ckpts = map(str, Path(run_path).rglob("*.ckpt"))
12
- ckpts = filter(lambda e: mode in os.path.basename(str(e)), ckpts)
13
- ckpts = sorted(ckpts)
14
- if len(ckpts):
15
- if len(ckpts) > 1 and 'last.ckpt' in ckpts:
16
- return ckpts[-2] # last.ckpt is always at the end, so we take the second last
17
- else:
18
- return ckpts[-1]
19
- else: return None
20
-
21
- def set_seed(seed: int):
22
- """Set seed"""
23
- random.seed(seed)
24
- np.random.seed(seed)
25
- torch.manual_seed(seed)
26
- if torch.cuda.is_available():
27
- torch.cuda.manual_seed(seed)
28
- torch.cuda.manual_seed_all(seed)
29
- torch.backends.cudnn.deterministic = True
30
- torch.backends.cudnn.benchmark = False
31
- os.environ["PYTHONHASHSEED"] = str(seed)
32
-
33
- @gin.configurable
34
- def build_warmed_exponential_lr_scheduler(
35
- optim: torch.optim.Optimizer, start_factor: float, peak_iteration: int,
36
- decay_factor: float=None, cycle_length: int=None, eta_min: float=None, eta_max: float=None) -> torch.optim.lr_scheduler._LRScheduler:
37
- linear = torch.optim.lr_scheduler.LinearLR(
38
- optim,
39
- start_factor=start_factor,
40
- end_factor=1.,
41
- total_iters=peak_iteration,
42
- )
43
- if decay_factor:
44
- exp = torch.optim.lr_scheduler.ExponentialLR(
45
- optim,
46
- gamma=decay_factor,
47
- )
48
- return torch.optim.lr_scheduler.SequentialLR(optim, [linear, exp],
49
- milestones=[peak_iteration])
50
- if cycle_length:
51
- cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
52
- optim,
53
- T_max=cycle_length,
54
- eta_min = eta_min * eta_max
55
- )
56
- return torch.optim.lr_scheduler.SequentialLR(optim, [linear, cosine],
57
- milestones=[peak_iteration])
58
-
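build_warmed_exponential_lr_scheduler chains a linear warm-up with either an exponential decay or a cosine cycle through SequentialLR. A minimal sketch of the exponential branch (the optimizer and hyperparameters below are illustrative):

import torch
from src.utils import build_warmed_exponential_lr_scheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optim = torch.optim.Adam(params, lr=1e-3)
sched = build_warmed_exponential_lr_scheduler(
    optim, start_factor=0.01, peak_iteration=1000, decay_factor=0.999995)

for step in range(2000):
    optim.step()
    sched.step()        # stepped every training iteration, matching 'interval': 'step'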
59
- def prob_mask_like(shape, prob, device):
60
- if prob == 1:
61
- return torch.ones(shape, device = device, dtype = torch.bool)
62
- elif prob == 0:
63
- return torch.zeros(shape, device = device, dtype = torch.bool)
64
- else:
65
- return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
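prob_mask_like draws a per-example boolean keep-mask and is what UNetPitchConditioned.forward uses to randomly swap pitch and singer conditioning for mask tokens during training (the classifier-free guidance dropout). A small sketch (the drop probability and singer ids are illustrative):

import torch
from src.utils import prob_mask_like

cond_drop_prob = 0.2
singer = torch.tensor([4, 25, 45, 32])
masked_token = 56                                   # index reserved for the dropped-singer case

keep = prob_mask_like((4,), 1.0 - cond_drop_prob, device=singer.device)
singer_in = torch.where(keep, singer, torch.full_like(singer, masked_token))
print(keep, singer_in)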