# the-jam-machine-app / playground.py
# Hosting-page metadata: last commit d08186b "changed instrument mapping" by misnaej;
# file size 5.3 kB; virus scan: no virus.
import matplotlib.pyplot as plt
import gradio as gr
from load import LoadModel
from generate import GenerateMidiText
from constants import INSTRUMENT_CLASSES, INSTRUMENT_TRANSFER_CLASSES
from decoder import TextDecoder
from utils import get_miditok, index_has_substring
from playback import get_music
from matplotlib import pylab
import sys
import matplotlib
from generation_utils import plot_piano_roll
import numpy as np
# Select matplotlib's non-interactive "Agg" backend so figures render off-screen.
matplotlib.use("Agg")
# Register matplotlib's pylab module under the top-level name "pylab", so a
# plain `import pylab` resolves to it.
# NOTE(review): presumably required by a dependency that imports pylab directly - confirm.
sys.modules["pylab"] = pylab
# Model repository and revision handed to LoadModel below (from_huggingface=True).
model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
revision = "ddf00f90d6d27e4cc0cb99c04a22a8f0a16c933e"
# NOTE(review): n_bar_generated is not referenced by the code visible in this file - confirm it is still needed.
n_bar_generated = 8
# Alternative 4-bar configuration, kept for reference:
# model_repo = "JammyMachina/improved_4bars-mdl"
# n_bar_generated = 4
# Load the model and tokenizer once at import time; generator() reads them as module globals.
model, tokenizer = LoadModel(
model_repo, from_huggingface=True, revision=revision
).load_model_and_tokenizer()
# Shared decoder; generator() calls decoder.get_midi(text, filename) on generated text.
miditok = get_miditok()
decoder = TextDecoder(miditok)
def define_prompt(state, genesis):
    """Return the text prompt used to seed generation of the next track.

    Args:
        state: Collection of tracks generated so far in this session.
        genesis: Object exposing ``get_whole_piece_from_bar_dict()``. It is
            only called when ``state`` is non-empty.

    Returns:
        ``"PIECE_START "`` when there is no track yet (``state`` is empty or
        ``None``), otherwise the value returned by
        ``genesis.get_whole_piece_from_bar_dict()``.
    """
    # Truthiness covers both an empty list and a missing (None) state.
    if not state:
        return "PIECE_START "
    return genesis.get_whole_piece_from_bar_dict()
def generator(
    label,
    regenerate,
    temp,
    density,
    instrument,
    state,
    piece_by_track,
    add_bars=False,
    add_bar_count=1,
):
    """Generate or regenerate one instrument track (or extend the piece) and render it.

    Args:
        label: Identifier of the UI row that triggered the call; used to find
            this row's previous entry in ``state``.
        regenerate: When true and this row already has a track, that track is
            deleted before a new one is generated.
        temp: Forwarded to ``genesis.generate_one_new_track``.
        density: Forwarded to ``genesis.generate_one_new_track``.
        instrument: Instrument name chosen in the UI. It is mapped to a
            ``family_number`` through ``INSTRUMENT_TRANSFER_CLASSES``; any name
            without a match falls back to ``"DRUMS"``.
        state: List of ``{"label", "text"}`` dicts, one per generated track.
            It is mutated in place and also returned.
        piece_by_track: Per-session piece data handed to ``GenerateMidiText``.
        add_bars: When true, extend the existing piece instead of generating
            a new track.
        add_bar_count: Number of bars passed to ``generate_n_more_bars``.

    Returns:
        A 7-tuple ``(inst_text, (44100, inst_audio), piano_roll, state,
        (44100, mixed_audio), regenerate, genesis.piece_by_track)`` in the
        order expected by the ``outputs`` list of the Generate button.

    Side effects:
        Writes ``mixed.mid`` and ``<instrument>.mid`` through
        ``decoder.get_midi`` using relative file names.
    """
    genesis = GenerateMidiText(model, tokenizer, piece_by_track)
    track = {"label": label}

    # Resolve the UI instrument name to the family identifier used for generation.
    family = next(
        (
            entry["family_number"]
            for entry in INSTRUMENT_TRANSFER_CLASSES
            if entry["transfer_to"] == instrument
        ),
        "DRUMS",
    )

    # Position of this row's previous track, if it has one (last match wins).
    existing_index = None
    for index, entry in enumerate(state):
        if entry["label"] == label:
            existing_index = index
    track_exists = existing_index is not None

    if not add_bars:
        # Only delete when this row really owns a track: popping index -1 for
        # an unknown label would remove another row's track (or raise on an
        # empty state).
        if regenerate and track_exists:
            state.pop(existing_index)
            genesis.delete_one_track(existing_index)
        input_prompt = define_prompt(state, genesis)
        generated_text = genesis.generate_one_new_track(
            family, density, temp, input_prompt=input_prompt
        )
        # -1 selects the last track, which is where the new track's state
        # entry is appended below.
        inst_index = -1
        regenerate = True  # the next click on this row replaces its track
    else:
        genesis.generate_n_more_bars(add_bar_count)  # for all instruments
        generated_text = genesis.get_whole_piece_from_bar_dict()
        inst_index = existing_index if track_exists else -1

    # Render the whole piece.
    decoder.get_midi(generated_text, "mixed.mid")
    mixed_inst_midi, mixed_audio = get_music("mixed.mid")

    # Render the selected instrument on its own.
    inst_text = genesis.get_selected_track_as_text(inst_index)
    inst_midi_name = f"{instrument}.mid"
    decoder.get_midi(inst_text, inst_midi_name)
    _, inst_audio = get_music(inst_midi_name)
    piano_roll = plot_piano_roll(mixed_inst_midi)

    track["text"] = inst_text
    if add_bars and track_exists:
        # Extending an existing track: refresh its entry rather than appending
        # a duplicate label, which would desynchronise state from the tracks.
        state[existing_index] = track
    else:
        state.append(track)

    # NOTE(review): 44100 is passed as the sample rate of both audio arrays -
    # confirm it matches what get_music renders.
    return (
        inst_text,
        (44100, inst_audio),
        piano_roll,
        state,
        (44100, mixed_audio),
        regenerate,
        genesis.piece_by_track,
    )
def instrument_row(default_inst, row_id):
    """Build one instrument row of the UI and wire its Generate button.

    Must be called inside the ``gr.Blocks`` context, after the module-level
    components ``state``, ``piece_by_track``, ``piano_roll`` and
    ``mixed_audio`` have been created, because the click handler references
    them as inputs/outputs.

    Args:
        default_inst: Instrument initially selected in the row's dropdown.
        row_id: Value stored in the row's hidden state and passed to
            ``generator`` as its ``label`` argument.
    """
    with gr.Row():
        # gr.State is the current name of the deprecated gr.Variable alias and
        # is what the rest of this file already uses.
        row = gr.State(row_id)
        with gr.Column(scale=1, min_width=50):
            instrument_choices = [
                entry["transfer_to"] for entry in INSTRUMENT_TRANSFER_CLASSES
            ] + ["Drums"]
            inst = gr.Dropdown(
                instrument_choices,
                value=default_inst,
                label="Instrument",
            )
            temp = gr.Number(value=0.7, label="Creativity")
            density = gr.Dropdown([0, 1, 2, 3], value=3, label="Density")
        with gr.Column(scale=3):
            output_txt = gr.Textbox(label="output", lines=10, max_lines=10)
        with gr.Column(scale=1, min_width=100):
            inst_audio = gr.Audio(label="Audio")
            # Hidden flag: generator() returns True for it after a generation,
            # so the next click on this row regenerates its track.
            regenerate = gr.Checkbox(value=False, label="Regenerate", visible=False)
            # Add-bars controls are not exposed yet; generator() therefore runs
            # with its add_bars / add_bar_count defaults.
            gen_btn = gr.Button("Generate")
            gen_btn.click(
                fn=generator,
                inputs=[row, regenerate, temp, density, inst, state, piece_by_track],
                outputs=[
                    output_txt,
                    inst_audio,
                    piano_roll,
                    state,
                    mixed_audio,
                    regenerate,
                    piece_by_track,
                ],
            )
# Assemble the UI: shared session state, the mixed outputs, then one row per
# default instrument. instrument_row() reads the four components created here
# as module globals, so they must exist before the rows are built.
with gr.Blocks() as demo:
    piece_by_track = gr.State([])
    state = gr.State([])
    mixed_audio = gr.Audio(label="Mixed Audio")
    piano_roll = gr.Plot(label="Piano Roll")
    # Row ids are the positions in this tuple (a "Piano" row is currently disabled).
    for row_id, default_inst in enumerate(
        ("Drums", "Synth Bass 1", "Synth Lead Square")
    ):
        instrument_row(default_inst, row_id)

demo.launch(debug=True)
"""
TODO: reset button
TODO: add a button to save the generated midi
TODO: add improvise button
TODO: set values for temperature as it is done for density
TODO: Add bars should now be set for the whole piece - regenerate should regenerate only the added bars, on all instruments
TODO: row height to fix
TODO: maybe reset the state of the tick boxes (regenerate, add bars) after use
TODO: block regenerate if add bar on
"""