# Hugging Face Spaces: Running on Zero
import gradio as gr | |
import spaces | |
import librosa | |
import soundfile as sf | |
import wavio | |
import os | |
import subprocess | |
import pickle | |
import torch | |
import torch.nn as nn | |
from transformers import T5Tokenizer | |
from transformer_model import Transformer | |
def save_wav(filepath, soundfont="soundfont.sf2", sample_rate=16000):
    """Render a MIDI file to WAV using the ``fluidsynth`` command-line tool.

    The MIDI input is ``<dir>/<stem>.mid`` and the WAV output is
    ``<dir>/<stem>.wav``, where ``<dir>`` and ``<stem>`` come from *filepath*
    (its extension is ignored).

    Args:
        filepath: Path identifying the MIDI file, e.g. ``"output.mid"``.
        soundfont: SoundFont file passed to fluidsynth.
        sample_rate: Output sample rate in Hz (fluidsynth ``-r``).

    Returns:
        Path of the WAV file fluidsynth was asked to write.

    Raises:
        FileNotFoundError: if the ``fluidsynth`` executable is not on PATH.
        subprocess.CalledProcessError: if fluidsynth exits with a non-zero
            status (previously such failures were silently ignored).
    """
    # Extract the directory and the stem (filename without extension).
    directory = os.path.dirname(filepath)
    stem = os.path.splitext(os.path.basename(filepath))[0]

    # Construct the full paths for the MIDI input and WAV output.
    midi_filepath = os.path.join(directory, f"{stem}.mid")
    wav_filepath = os.path.join(directory, f"{stem}.wav")

    # Pass an argument list instead of a shell string so paths containing
    # spaces or shell metacharacters reach fluidsynth verbatim. stdout is
    # discarded, as the former "> /dev/null" did; stderr stays visible.
    subprocess.run(
        [
            "fluidsynth",
            "-r", str(sample_rate),
            soundfont,
            "-g", "1.0",
            "--quiet",
            "--no-shell",
            midi_filepath,
            "-T", "wav",
            "-F", wav_filepath,
        ],
        stdout=subprocess.DEVNULL,
        check=True,
    )
    return wav_filepath
def generate_midi(caption, temperature=0.9, max_len=500, output_path="output.mid"):
    """Generate a MIDI file from a text caption with the Text2midi model.

    Encodes *caption* with the ``google/flan-t5-base`` tokenizer, runs the
    ``Transformer`` loaded from ``artifacts/pytorch_model_95.bin``, decodes the
    generated ids with the pickled tokenizer ``artifacts/vocab_remi.pkl`` and
    writes the result with ``dump_midi``.

    Args:
        caption: Text prompt describing the music.
        temperature: Sampling temperature forwarded to ``model.generate``.
        max_len: Length limit forwarded to ``model.generate`` (presumably a
            token count -- confirm against transformer_model).
        output_path: File the decoded MIDI is written to. Defaults to
            ``"output.mid"``, the path ``gradio_generate`` reads back.

    Returns:
        *output_path*.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    artifact_folder = 'artifacts'

    # Load the decoding tokenizer.
    # NOTE(security): pickle.load can execute arbitrary code; acceptable only
    # because this is a bundled artifact, never user-supplied data.
    tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")
    with open(tokenizer_filepath, "rb") as f:
        r_tokenizer = pickle.load(f)

    # Get the vocab size.
    vocab_size = len(r_tokenizer)
    print("Vocab size: ", vocab_size)

    # The positional hyper-parameters must match the checkpoint; their meaning
    # is defined by transformer_model.Transformer (not visible here).
    model = Transformer(vocab_size, 768, 8, 5000, 18, 1024, False, 8, device=device)
    model_path = os.path.join(artifact_folder, "pytorch_model_95.bin")
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
    # return_tensors='pt' with padding=True already yields rectangular
    # (batch, seq) tensors, so re-padding them is unnecessary.
    inputs = tokenizer(caption, return_tensors='pt', padding=True, truncation=True)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    # Inference only: disable autograd bookkeeping to save memory.
    with torch.no_grad():
        output = model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature)

    output_list = output[0].tolist()
    generated_midi = r_tokenizer.decode(output_list)
    print(generated_midi)
    generated_midi.dump_midi(output_path)
    return output_path
# @spaces.GPU(duration=120) | |
# def gradio_generate(prompt, temperature, max_length): | |
# # Generate midi | |
# generate_midi(prompt, temperature, max_length) | |
# # Convert midi to wav | |
# filename = "output.mid" | |
# save_wav(filename) | |
# filename = filename.replace(".mid", ".wav") | |
# # Read the generated WAV file | |
# output_wave, samplerate = sf.read(filename, dtype='float32') | |
# output_filename = "temp.wav" | |
# wavio.write(output_filename, output_wave, rate=16000, sampwidth=2) | |
# return output_filename | |
def gradio_generate(prompt, temperature, max_length):
    """Gradio callback: generate MIDI for *prompt* and render it to audio.

    Args:
        prompt: Text caption describing the music.
        temperature: Sampling temperature forwarded to ``generate_midi``.
        max_length: Length limit forwarded to ``generate_midi``. Gradio
            ``Number`` components deliver floats unless ``precision=0`` is
            set, so the value is coerced to ``int`` here.

    Returns:
        ``(wav_path, midi_path)``: a 16-bit WAV for the audio player and the
        MIDI file for download.
    """
    # generate_midi writes its result to "output.mid".
    midi_filename = "output.mid"
    generate_midi(prompt, temperature, int(max_length))

    # Convert MIDI to WAV; use the path save_wav reports rather than
    # re-deriving it here.
    wav_filename = save_wav(midi_filename)

    # Re-encode as 16-bit PCM at the rate the file was actually rendered at,
    # instead of assuming 16 kHz.
    output_wave, samplerate = sf.read(wav_filename, dtype='float32')
    temp_wav_filename = "temp.wav"
    wavio.write(temp_wav_filename, output_wave, rate=samplerate, sampwidth=2)
    return temp_wav_filename, midi_filename  # Return both WAV and MIDI file paths
title = "Text2midi: Generating Symbolic Music from Captions"

# HTML shown under the title. The closing tag was previously written as
# "<p/>", which is not a closing tag and opened a stray paragraph instead.
# TODO: replace the "tbd" arXiv placeholder with the real paper link.
description_text = """
<p><a href="https://huggingface.co/spaces/amaai-lab/text2midi/blob/main/app.py?duplicate=true"> <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> For faster inference without waiting in queue, you may duplicate the space and upgrade to a GPU in the settings. <br/><br/>
Generate midi music using Text2midi by providing a text prompt.
<br/><br/> This is the demo for Text2midi for controllable text to midi generation: <a href="https://arxiv.org/abs/tbd">Read our paper.</a>
</p>
"""
#description_text = ""

# Gradio input and output components
input_text = gr.Textbox(lines=2, label="Prompt")
output_audio = gr.Audio(label="Generated Music", type="filepath")
output_midi = gr.File(label="Download MIDI File")
temperature = gr.Slider(minimum=0.5, maximum=1.2, value=1.0, step=0.1, label="Temperature", interactive=True)
# precision=0 makes Gradio round and deliver an int (e.g. 500) instead of a
# float (500.0) for this length limit.
max_length = gr.Number(value=500, label="Max Length", minimum=100, maximum=2000, step=100, precision=0)
# CSS styling for the Duplicate button
css = '''
#duplicate-button {
    margin: auto;
    color: white;
    background: #1565c0;
    border-radius: 100vh;
}
'''

# Gradio interface
gr_interface = gr.Interface(
    fn=gradio_generate,
    inputs=[input_text, temperature, max_length],
    outputs=[output_audio, output_midi],
    description=description_text,
    # Gradio expects "never"/"auto"/"manual"; the boolean form is deprecated.
    allow_flagging="never",
    # Every example supplies a value for each of the three inputs (prompt,
    # temperature, max length) so cached example runs call gradio_generate
    # with all of its arguments. The extra values match the component defaults.
    examples=[
        ["A cheerful and melodic pop Christmas song featuring piano, acoustic guitar, vibraphone, bass, and drums, set in the key of Eb minor with a fast tempo of 123 bpm and a 4/4 time signature, creating a joyful and relaxing atmosphere.", 1.0, 500],
        ["A melodic electronic song with ambient elements, featuring piano, acoustic guitar, alto saxophone, string ensemble, and electric bass. Set in G minor with a 4/4 time signature, it moves at a lively Presto tempo. The composition evokes a blend of relaxation and darkness, with hints of happiness and a meditative quality.", 1.0, 500],
        ["This motivational electronic and pop song features a clean electric guitar, rock organ, synth voice, acoustic guitar, and vibraphone, creating a melodic and uplifting atmosphere. Set in the key of G# minor with a 4/4 time signature, the track moves at an energetic Allegro tempo of 120 beats per minute. The chord progression of Bbm7 and F# adds to the song's inspiring and corporate feel.", 1.0, 500],
        ["This short electronic song in C minor features a brass section, string ensemble, tenor saxophone, clean electric guitar, and slap bass, creating a melodic and slightly dark atmosphere. With a tempo of 124 BPM (Allegro) and a 4/4 time signature, the track incorporates a chord progression of C7/E, Eb6, and Bbm6, adding a touch of corporate and motivational vibes to the overall composition.", 1.0, 500],
        ["An energetic and melodic electronic trance track with a space and retro vibe, featuring drums, distortion guitar, flute, synth bass, and slap bass. Set in A minor with a fast tempo of 138 BPM, the song maintains a 4/4 time signature throughout its duration.", 1.0, 500],
        ["A short but energetic rock fragment in C minor, featuring overdriven guitars, electric bass, and drums, with a vivacious tempo of 155 BPM and a 4/4 time signature, evoking a blend of dark and melodic tones.", 1.0, 500],
    ],
    cache_examples="lazy",
)

with gr.Blocks(css=css) as demo:
    # Components render on creation inside the Blocks context. The header is
    # bound to its own name so the module-level `title` string is not
    # overwritten by a component object.
    title_html = gr.HTML(f"<h1><center>{title}</center></h1>")
    dupe = gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    gr_interface.render()

# Launch Gradio app.
demo.queue().launch()