import sys
import os
# Check if running in debug mode
debug_mode = '--debug' in sys.argv or os.environ.get('DEBUG') == 'True'

if debug_mode:
    # Path to the local version of the package
    local_package_path = "../../GaMaDHaNi-dev"
    
    # Add the local package path to sys.path
    sys.path.insert(0, local_package_path)
    
    print(f"Running in debug mode. Using package from: {local_package_path}")
    import pyprofilers as pp
else:
    print("Running in normal mode. Using package from site-packages.")

import spaces
import gradio as gr
import numpy as np
import torch
import librosa
import matplotlib.pyplot as plt
import pandas as pd
from functools import partial
import gin
import torchaudio
from absl import app
from torch.nn.functional import interpolate
import logging
import crepe
from hmmlearn import hmm
import soundfile as sf
import pdb
from gamadhani.utils.generate_utils import load_pitch_fns, load_audio_fns
import gamadhani.utils.pitch_to_audio_utils as p2a
from gamadhani.utils.utils import get_device

import copy

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
pitch_paths = {
    'Diffusion Pitch Generator': ('diffusion', 'models/diffusion_pitch/'),
    'Autoregressive Pitch Generator': ('transformer', 'models/transformer_pitch/')
    }
model_loaded = None
audio_path = 'models/pitch_to_audio/'
device = get_device()

def debug_profile(func):
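    # Wrap the function with pyprofilers profiling only when debug mode is
    # active; otherwise return it unchanged.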
    if debug_mode:
        return pp.profile(sort_by='cumulative', out_lines=10)(func)
    return func

def predict_voicing(confidence):
    # https://github.com/marl/crepe/pull/26
    """
    Find the Viterbi path for voiced versus unvoiced frames.
    Parameters
    ----------
    confidence : np.ndarray [shape=(N,)]
        voicing confidence array, i.e. the confidence in the presence of
        a pitch
    Returns
    -------
    voicing_states : np.ndarray [shape=(N,)]
        HMM predictions for each frames state, 0 if unvoiced, 1 if
        voiced
    """
    # uniform prior on the voicing confidence
    starting = np.array([0.5, 0.5])

    # transition probabilities inducing continuous voicing state
    transition = np.array([[0.99, 0.01], [0.01, 0.99]])

    # mean and variance for unvoiced and voiced states
    means = np.array([[0.0], [1.0]])
    variances = np.array([[0.25], [0.25]])

    # fix the model parameters because we are not optimizing the model
    model = hmm.GaussianHMM(n_components=2)
    model.startprob_, model.covars_, model.transmat_, model.means_, \
    model.n_features = starting, variances, transition, means, 1

    # find the Viterbi path
    voicing_states = model.predict(confidence.reshape(-1, 1), [len(confidence)])

    return np.array(voicing_states)

def extract_pitch(audio, unvoice=True, sr=16000, frame_shift_ms=10, log=True):
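    # CREPE estimates pitch every frame_shift_ms milliseconds (10 ms -> 100 Hz
    # frame rate); frames the HMM labels as unvoiced are zeroed out when unvoice=True.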
    time, frequency, confidence, _ = crepe.predict(
      audio, sr=sr,
      viterbi=True,
      step_size=frame_shift_ms,
      verbose=0 if not log else 1)
    f0 = frequency
    if unvoice:
      is_voiced = predict_voicing(confidence)
      frequency_unvoiced = frequency * is_voiced
      f0 = frequency_unvoiced

    return time, f0, confidence

def generate_pitch_reinterp(pitch, pitch_model, invert_pitch_fn, num_samples, num_steps, noise_std=0.4, t0=0.5):
    '''Generate pitch values for the melodic reinterpretation task'''
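    # t0 controls how closely the generated contour follows the input contour
    # (exposed in the UI as "Faithfulness to the input").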
    samples = pitch_model.sample_sdedit(pitch[:, :, -1200:].to(pitch_model.device), num_samples, num_steps, t0=t0)
    inverted_pitches = invert_pitch_fn(f0=samples.detach().cpu().numpy()[0]).flatten()   # pitch values in Hz

    return samples, inverted_pitches

def generate_pitch_response(pitch, pitch_model, invert_pitch_fn, num_samples, num_steps, model_type='diffusion'):
    '''Generate pitch values for the call and response task'''
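    # At a 100 Hz pitch frame rate, the last 400 frames (4 s) serve as the
    # prime that the model continues.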
    pitch = pitch[:, :, -400:]   # consider only the last 4 s of the pitch contour
    if model_type == 'diffusion':
        samples = pitch_model.sample_fn(num_samples, num_steps, prime=pitch)
    else:
        samples = pitch_model.sample_fn(batch_size=num_samples, seq_len=800, prime=pitch)
    inverted_pitches = invert_pitch_fn(f0=samples.clone().detach().cpu().numpy()[0]).flatten()   # pitch values in Hz

    return samples, inverted_pitches

def generate_audio(audio_model, f0s, invert_audio_fn, singers=[3], num_steps=100):
    '''Generate audio given pitch values'''
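    # The singer id is repeated across the batch and passed to sample_cfg as the
    # conditioning signal for guided sampling (strength=3).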
    singer_tensor = torch.tensor(np.repeat(singers, repeats=f0s.shape[0])).to(audio_model.device)
    samples, _, singers = audio_model.sample_cfg(f0s.shape[0], f0=f0s, num_steps=num_steps, singer=singer_tensor, strength=3)
    audio = invert_audio_fn(samples)

    return audio
    
@spaces.GPU(duration=30)
def generate(pitch, num_samples=1, num_steps=100, singers=[3], outfolder='temp', audio_seq_len=750, pitch_qt=None, type='response', invert_pitch_fn=None, t0=0.5, model_type='diffusion'):
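    # Pipeline: sample a new pitch contour (response or reinterpretation), undo
    # the quantile transform, interpolate to the audio model's input length,
    # synthesize audio, and plot the generated pitch.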
    global pitch_model, audio_model
    # move the models to device
    pitch_model = pitch_model.to(device)
    audio_model = audio_model.to(device)
    logging.log(logging.INFO, 'Generate function')
    # load pitch values onto GPU
    pitch = torch.tensor(pitch).float().unsqueeze(0).unsqueeze(0).to(device)
    if pitch_qt is not None:
        pitch_qt = p2a.GPUQuantileTransformer(pitch_qt, device=device)
    logging.log(logging.INFO, 'Generating pitch')
    if type == 'response':
        pitch, inverted_pitch = generate_pitch_response(pitch, pitch_model, invert_pitch_fn, num_samples=num_samples, num_steps=num_steps, model_type=model_type)
    elif type == 'reinterp':
        pitch, inverted_pitch = generate_pitch_reinterp(pitch, pitch_model, invert_pitch_fn, num_samples=num_samples, num_steps=num_steps, t0=t0)
    
    else:
        raise ValueError(f'Invalid type: {type}')

    if pitch_qt is not None:
        # if a pitch quantile transformer is used, undo the quantile transformation applied to the model output
        def undo_qt(x, min_clip=200):
            pitch = pitch_qt.inverse_transform(x).squeeze(0)  # qt transform expects shape (bs, seq_len, 1)
            pitch = torch.round(pitch)  # round to the nearest integer, matching the preprocessing of the pitch contour fed into the model
            pitch[pitch < min_clip] = np.nan  # treat values below min_clip as unvoiced/silent
            pitch = pitch.unsqueeze(0)
            return pitch
        pitch = undo_qt(pitch)
    interpolated_pitch = p2a.interpolate_pitch(pitch=pitch, audio_seq_len=audio_seq_len).squeeze(0)    # interpolate pitch values to match the audio model's input size
    interpolated_pitch = torch.nan_to_num(interpolated_pitch, nan=196)  # replace nan values with silent token
    interpolated_pitch = interpolated_pitch.squeeze(1) # to match input size by removing the extra dimension
    logging.log(logging.INFO, 'Generating audio')
    audio = generate_audio(audio_model, interpolated_pitch, invert_audio_fn, singers=singers, num_steps=num_steps)
    audio = audio.detach().cpu().numpy()
    pitch = pitch.detach().cpu().numpy()
    # generate plot of model output to display on interface
    model_output_plot = plt.figure()
    inverted_pitch = np.where(inverted_pitch == 0, np.nan, inverted_pitch)
    plt.plot(inverted_pitch, figure=model_output_plot, label='Model Output')
    plt.close(model_output_plot)
    return (16000, audio[0]), model_output_plot # return audio and plot

pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn = None, None, None, None # initialize pitch model based on user preference
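# The pitch generator is loaded lazily based on the user's selection (see
# load_pitch_model); the pitch-to-audio model is loaded once at startup on CPU
# and moved to the accelerator inside generate().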
audio_model, audio_qt, audio_seq_len, invert_audio_fn = load_audio_fns(
    os.path.join(audio_path, 'last.ckpt'),
    qt_path = os.path.join(audio_path, 'qt.joblib'),
    config_path = os.path.join(audio_path, 'config.gin'),
    device = 'cpu'
)


def load_pitch_model(model_selection):
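    # The diffusion pitch model ships with a quantile transformer (qt.joblib);
    # the autoregressive (transformer) model does not use one.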
    model_type, pitch_path = pitch_paths[model_selection]
    pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn, _ = load_pitch_fns(
        os.path.join(pitch_path, 'model.ckpt'),
        model_type=model_type,
        config_path=os.path.join(pitch_path, 'config.gin'),
        qt_path=os.path.join(pitch_path, 'qt.joblib') if model_type == 'diffusion' else None,
        device='cpu'
    )
    return pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn

@debug_profile
def container_generate(model_selection, task_selection, audio, singer_id, t0):
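    # Gradio callback: reuse the cached pitch model when possible, pad and
    # resample the recording, extract its pitch with CREPE, and dispatch to the
    # selected task.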
    global pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn, model_loaded
    # load pitch model
    if model_loaded is None or model_loaded != model_selection:
        pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn = load_pitch_model(model_selection)
        model_loaded = model_selection
    else:
        logging.log(logging.INFO, f'using existing model: {model_selection}')
    # extract pitch from input
    if audio is None:
        return None, None, None  # match the three Gradio outputs: audio, input plot, output plot
    sr, audio = audio
    if len(audio) < 12*sr and task_selection == 'Melodic Reinterpretation':    
        # make sure the audio is at least 12 s long
        audio = np.pad(audio, (0, 12*sr - len(audio)), mode='constant')
    if len(audio) < 4*sr and task_selection == 'Call and Response':     
        # make sure the audio is at least 4 s long
        audio = np.pad(audio, (4*sr - len(audio), 0), mode='constant')
    audio = audio.astype(np.float32)
    audio /= np.max(np.abs(audio))
    audio = librosa.resample(audio, orig_sr=sr, target_sr=16000) # resample the full recording to 16 kHz
    mic_audio = audio.copy()
    audio = audio[-12*16000:] # consider only last 12 s
    _, f0, _ = extract_pitch(audio)
    mic_f0 = f0.copy() # save the user input pitch values
    logging.log(logging.INFO, 'Pitch extracted')
    f0 = pitch_task_fn(**{
        'inputs': {
            'pitch': {
                'data': torch.Tensor(f0), # task function expects a tensor
                'sampling_rate': 100
                }
        }, 
        'qt_transform': pitch_qt,
        'time_downsample': 1, # pitch will be extracted at 100 Hz, thus no downsampling
        'seq_len': None,
    })['sampled_sequence']
    # f0 = torch.tensor(f0).to(pitch_model.device).float()
    logging.log(logging.INFO, 'Calling generate function')
    mic_f0 = np.where(mic_f0 == 0, np.nan, mic_f0)
    # plot user input
    user_input_plot = plt.figure()
    plt.plot(np.arange(0, len(mic_f0)), mic_f0, label='User Input', figure=user_input_plot)
    plt.close(user_input_plot)
    
    if singer_id == 'Singer 1':
        singer = [3]
    elif singer_id == 'Singer 2':
        singer = [27]
    else:
        singer = [3]    # default to Singer 1 if no singer is selected
    model_type = pitch_paths[model_selection][0]    # 'diffusion' or 'transformer', as expected by the pitch samplers
    if task_selection == 'Call and Response':
        partial_generate = partial(generate, num_samples=1, num_steps=100, singers=singer, outfolder=None, pitch_qt=pitch_qt, type='response', invert_pitch_fn=invert_pitch_fn, model_type=model_type) 
    else:
        partial_generate = partial(generate, num_samples=1, num_steps=100, singers=singer, outfolder=None, pitch_qt=pitch_qt, type='reinterp', invert_pitch_fn=invert_pitch_fn, t0=t0, model_type=model_type) 
    audio, output_plot = partial_generate(f0)
    return audio, user_input_plot, output_plot

css = """
.center-text {
    text-align: center;
}
.justify-text {
    text-align: justify;
}   
"""

def toggle_visibility(selection):
    # Show the t0 slider only when the melodic reinterpretation task is selected
    if selection == "Melodic Reinterpretation":
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)
    
def toggle_options(selection, options=['Call and Response', 'Melodic Reinterpretation']):
    # Offer both tasks for the diffusion model; the autoregressive model only supports call and response
    if selection == "Diffusion Pitch Generator":
        return gr.update(choices=options)
    else:
        return gr.update(choices=options[:-1])

with gr.Blocks(css=css) as demo:
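    # UI layout: model/task/singer selectors, a t0 slider shown only for the
    # melodic reinterpretation task, audio input/output, and pitch plots.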
    gr.Markdown("# GaMaDHaNi: Hierarchical Generative Modeling of Melodic Vocal Contours in Hindustani Classical Music", elem_classes="center-text")
    gr.Markdown("### Abstract", elem_classes="center-text")
    gr.Markdown("""
        Hindustani music is a performance-driven oral tradition that exhibits the rendition of rich melodic patterns. In this paper, we focus on generative modeling of singers' vocal melodies extracted from audio recordings, as the voice is musically prominent within the tradition. Prior generative work in Hindustani music models melodies as coarse discrete symbols, which fails to capture the rich expressive melodic intricacies of singing. Thus, we propose to use a finely quantized pitch contour as an intermediate representation for hierarchical audio modeling. We propose GaMaDHaNi, a modular two-level hierarchy, consisting of a generative model on pitch contours, and a pitch contour to audio synthesis model. We compare our approach to non-hierarchical audio models and hierarchical models that use a self-supervised intermediate representation, through a listening test and qualitative analysis. We also evaluate the audio model's ability to faithfully represent the pitch contour input using the Pearson correlation coefficient. By using pitch contours as an intermediate representation, we show that our model may be better equipped to listen and respond to musicians in a human-AI collaborative setting by highlighting two potential interaction use cases: (1) primed generation, and (2) coarse pitch conditioning.
    """, elem_classes="justify-text")
    gr.Markdown("""
                    πŸ“– Read more about the project [here](https://arxiv.org/pdf/2408.12658) <br>
                    🎧 Listen to the samples [here](https://snnithya.github.io/gamadhani-samples) <br>
    """, elem_classes="center-text")
    with gr.Column():
        gr.Markdown("""
                    ## Instructions
                    In this demo you can interact with the model in two ways: 
                    1. **[Call and response](https://snnithya.github.io/gamadhani-samples/5primed_generation/)**: The model will try to continue the idea that you input. This is similar to 'primed generation' discussed in the paper. The last 4 s of the audio will be considered as a 'prime' for the model to continue. <br><br>
                    2. **[Melodic reinterpretation](https://snnithya.github.io/gamadhani-samples/6coarsepitch/)**: Akin to the idea of 'coarse pitch conditioning' presented in the paper, you can input a pitch contour and the model will generate audio that is similar to, but not exactly the same as, your input. <br><br>
                    ### Upload an audio file or record your voice to get started!
                    """)
        gr.Markdown("""
                    This is still a work in progress, so please feel free to share any weird or interesting examples; we would love to hear them! Contact us at snnithya[at]mit[dot]edu.
                    """)
        gr.Markdown("""
            *Note: If you see an error message on the screen after clicking 'Run', please wait for five seconds and click 'Run' again.*
        """)
        gr.Markdown("""
            *Another note: The model may take around 20-30s to generate an output. Hang tight! But if you're left hanging for too long, let me know!*
            """)
        gr.Markdown("""
            *Last note, I promise: There are some example audio samples at the bottom of the page. You can start with those if you'd like!*
            """)
    model_dropdown = gr.Dropdown(["Diffusion Pitch Generator", "Autoregressive Pitch Generator"], label="Select a model type")
    task_dropdown = gr.Dropdown(label="Select a task", choices=["Call and Response", "Melodic Reinterpretation"])
    model_dropdown.change(toggle_options, inputs=model_dropdown, outputs=task_dropdown)
    t0 = gr.Slider(label="Faithfulness to the input (For melodic reinterpretation task only)", minimum=0.0, maximum=1.0, step=0.01, value=0.3, visible=False)
    task_dropdown.change(toggle_visibility, inputs=task_dropdown, outputs=t0)
    singer_dropdown = gr.Dropdown(label="Select a singer", choices=["Singer 1", "Singer 2"])
    with gr.Row(equal_height=True):
        with gr.Column():    
            audio = gr.Audio(label="Input", show_download_button=True)
            examples = gr.Examples(
                examples=[
                    ["examples/ex1.wav"],
                    ["examples/ex2.wav"],
                    ["examples/ex3.wav"],
                    ["examples/ex4.wav"],
                    ["examples/ex5.wav"]
                ],
                inputs=audio
            )
        with gr.Column():
            generated_audio = gr.Audio(label="Generated Audio", elem_id="audio")
    with gr.Row():
        with gr.Column():
            with gr.Accordion("View Pitch Plot"):
                user_input = gr.Plot(label="User Input")  
        with gr.Column():
            with gr.Accordion("View Pitch Plot"):
                generated_pitch = gr.Plot(label="Generated Pitch")
    sbmt = gr.Button()
    sbmt.click(container_generate, inputs=[model_dropdown, task_dropdown, audio, singer_dropdown, t0], outputs=[generated_audio, user_input, generated_pitch])

def main(argv):
    demo.launch()

if __name__ == '__main__':
    main(sys.argv)
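
# Example (assumed local usage; the script name app.py is a placeholder): launch the demo with
#   python app.py
# To use the local GaMaDHaNi checkout with pyprofilers profiling enabled, run
#   python app.py --debug
# or set the environment variable DEBUG=True before launching.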