Spaces:
Paused
Paused
add seed
Browse files
app.py
CHANGED
@@ -14,6 +14,7 @@ import MIDI
|
|
14 |
from midi_synthesizer import synthesis
|
15 |
from midi_tokenizer import MIDITokenizer
|
16 |
|
|
|
17 |
in_space = os.getenv("SYSTEM") == "spaces"
|
18 |
|
19 |
|
@@ -23,7 +24,9 @@ def softmax(x, axis):
|
|
23 |
return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
|
24 |
|
25 |
|
26 |
-
def sample_top_p_k(probs, p, k):
|
|
|
|
|
27 |
probs_idx = np.argsort(-probs, axis=-1)
|
28 |
probs_sort = np.take_along_axis(probs, probs_idx, -1)
|
29 |
probs_sum = np.cumsum(probs_sort, axis=-1)
|
@@ -36,17 +39,19 @@ def sample_top_p_k(probs, p, k):
|
|
36 |
shape = probs_sort.shape
|
37 |
probs_sort_flat = probs_sort.reshape(-1, shape[-1])
|
38 |
probs_idx_flat = probs_idx.reshape(-1, shape[-1])
|
39 |
-
next_token = np.stack([
|
40 |
next_token = next_token.reshape(*shape[:-1])
|
41 |
return next_token
|
42 |
|
43 |
|
44 |
def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
45 |
-
disable_patch_change=False, disable_control_change=False, disable_channels=None):
|
46 |
if disable_channels is not None:
|
47 |
disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
|
48 |
else:
|
49 |
disable_channels = []
|
|
|
|
|
50 |
max_token_seq = tokenizer.max_token_seq
|
51 |
if prompt is None:
|
52 |
input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
|
@@ -83,7 +88,7 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
83 |
mask[mask_ids] = 1
|
84 |
logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
|
85 |
scores = softmax(logits / temp, -1) * mask
|
86 |
-
sample = sample_top_p_k(scores, top_p, top_k)
|
87 |
if i == 0:
|
88 |
next_token_seq = sample
|
89 |
eid = sample.item()
|
@@ -120,13 +125,16 @@ def send_msgs(msgs, msgs_history=None):
|
|
120 |
return json.dumps(msgs_history)
|
121 |
|
122 |
|
123 |
-
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events,
|
|
|
124 |
msgs_history = []
|
125 |
mid_seq = []
|
126 |
bpm = int(bpm)
|
127 |
gen_events = int(gen_events)
|
128 |
max_len = gen_events
|
129 |
-
|
|
|
|
|
130 |
disable_patch_change = False
|
131 |
disable_channels = None
|
132 |
if tab == 0:
|
@@ -159,22 +167,22 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, gen_event
|
|
159 |
init_msgs = [create_msg("visualizer_clear", False)]
|
160 |
for tokens in mid_seq:
|
161 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
162 |
-
yield mid_seq, None, None, send_msgs(init_msgs, msgs_history)
|
163 |
model = models[model_name]
|
164 |
-
|
165 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
166 |
-
disable_channels=disable_channels)
|
167 |
-
for i, token_seq in enumerate(
|
168 |
token_seq = token_seq.tolist()
|
169 |
mid_seq.append(token_seq)
|
170 |
event = tokenizer.tokens2event(token_seq)
|
171 |
-
yield mid_seq, None, None, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history)
|
172 |
mid = tokenizer.detokenize(mid_seq)
|
173 |
with open(f"output.mid", 'wb') as f:
|
174 |
f.write(MIDI.score2midi(mid))
|
175 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
176 |
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
|
177 |
-
yield mid_seq, "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", events)])
|
178 |
|
179 |
|
180 |
def cancel_run(mid_seq):
|
@@ -232,8 +240,8 @@ if __name__ == "__main__":
|
|
232 |
opt = parser.parse_args()
|
233 |
soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
234 |
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
235 |
-
"j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
|
236 |
-
"touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
237 |
}
|
238 |
models = {}
|
239 |
tokenizer = MIDITokenizer()
|
@@ -301,7 +309,10 @@ if __name__ == "__main__":
|
|
301 |
|
302 |
tab1.select(lambda: 0, None, tab_select, queue=False)
|
303 |
tab2.select(lambda: 1, None, tab_select, queue=False)
|
304 |
-
|
|
|
|
|
|
|
305 |
step=1, value=opt.max_gen // 2)
|
306 |
with gr.Accordion("options", open=False):
|
307 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
@@ -316,9 +327,9 @@ if __name__ == "__main__":
|
|
316 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
317 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
318 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
|
319 |
-
input_midi, input_midi_events,
|
320 |
-
input_top_p, input_top_k, input_allow_cc],
|
321 |
-
[output_midi_seq, output_midi, output_audio, js_msg],
|
322 |
concurrency_limit=3)
|
323 |
stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
324 |
app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
|
|
14 |
from midi_synthesizer import synthesis
|
15 |
from midi_tokenizer import MIDITokenizer
|
16 |
|
17 |
+
MAX_SEED = np.iinfo(np.int32).max
|
18 |
in_space = os.getenv("SYSTEM") == "spaces"
|
19 |
|
20 |
|
|
|
24 |
return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
|
25 |
|
26 |
|
27 |
+
def sample_top_p_k(probs, p, k, generator=None):
|
28 |
+
if generator is None:
|
29 |
+
generator = np.random
|
30 |
probs_idx = np.argsort(-probs, axis=-1)
|
31 |
probs_sort = np.take_along_axis(probs, probs_idx, -1)
|
32 |
probs_sum = np.cumsum(probs_sort, axis=-1)
|
|
|
39 |
shape = probs_sort.shape
|
40 |
probs_sort_flat = probs_sort.reshape(-1, shape[-1])
|
41 |
probs_idx_flat = probs_idx.reshape(-1, shape[-1])
|
42 |
+
next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
|
43 |
next_token = next_token.reshape(*shape[:-1])
|
44 |
return next_token
|
45 |
|
46 |
|
47 |
def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
48 |
+
disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
|
49 |
if disable_channels is not None:
|
50 |
disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
|
51 |
else:
|
52 |
disable_channels = []
|
53 |
+
if generator is None:
|
54 |
+
generator = np.random
|
55 |
max_token_seq = tokenizer.max_token_seq
|
56 |
if prompt is None:
|
57 |
input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
|
|
|
88 |
mask[mask_ids] = 1
|
89 |
logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
|
90 |
scores = softmax(logits / temp, -1) * mask
|
91 |
+
sample = sample_top_p_k(scores, top_p, top_k, generator)
|
92 |
if i == 0:
|
93 |
next_token_seq = sample
|
94 |
eid = sample.item()
|
|
|
125 |
return json.dumps(msgs_history)
|
126 |
|
127 |
|
128 |
+
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, seed, seed_rand,
|
129 |
+
gen_events, temp, top_p, top_k, allow_cc):
|
130 |
msgs_history = []
|
131 |
mid_seq = []
|
132 |
bpm = int(bpm)
|
133 |
gen_events = int(gen_events)
|
134 |
max_len = gen_events
|
135 |
+
if seed_rand:
|
136 |
+
seed = np.random.randint(0, MAX_SEED)
|
137 |
+
generator = np.random.RandomState(seed)
|
138 |
disable_patch_change = False
|
139 |
disable_channels = None
|
140 |
if tab == 0:
|
|
|
167 |
init_msgs = [create_msg("visualizer_clear", False)]
|
168 |
for tokens in mid_seq:
|
169 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
170 |
+
yield mid_seq, None, None, seed, send_msgs(init_msgs, msgs_history)
|
171 |
model = models[model_name]
|
172 |
+
midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
173 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
174 |
+
disable_channels=disable_channels, generator=generator)
|
175 |
+
for i, token_seq in enumerate(midi_generator):
|
176 |
token_seq = token_seq.tolist()
|
177 |
mid_seq.append(token_seq)
|
178 |
event = tokenizer.tokens2event(token_seq)
|
179 |
+
yield mid_seq, None, None, seed, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history)
|
180 |
mid = tokenizer.detokenize(mid_seq)
|
181 |
with open(f"output.mid", 'wb') as f:
|
182 |
f.write(MIDI.score2midi(mid))
|
183 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
184 |
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
|
185 |
+
yield mid_seq, "output.mid", (44100, audio), seed, send_msgs([create_msg("visualizer_end", events)])
|
186 |
|
187 |
|
188 |
def cancel_run(mid_seq):
|
|
|
240 |
opt = parser.parse_args()
|
241 |
soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
242 |
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
243 |
+
# "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
|
244 |
+
# "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
245 |
}
|
246 |
models = {}
|
247 |
tokenizer = MIDITokenizer()
|
|
|
309 |
|
310 |
tab1.select(lambda: 0, None, tab_select, queue=False)
|
311 |
tab2.select(lambda: 1, None, tab_select, queue=False)
|
312 |
+
input_seed = gr.Slider(label="seed", minimum=0, maximum=2 ** 31 - 1,
|
313 |
+
step=1, value=0)
|
314 |
+
input_seed_rand = gr.Checkbox(label="random seed", value=True)
|
315 |
+
input_gen_events = gr.Slider(label="generate max n midi events", minimum=1, maximum=opt.max_gen,
|
316 |
step=1, value=opt.max_gen // 2)
|
317 |
with gr.Accordion("options", open=False):
|
318 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
|
|
327 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
328 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
329 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
|
330 |
+
input_midi, input_midi_events, input_seed, input_seed_rand, input_gen_events,
|
331 |
+
input_temp, input_top_p, input_top_k, input_allow_cc],
|
332 |
+
[output_midi_seq, output_midi, output_audio, input_seed, js_msg],
|
333 |
concurrency_limit=3)
|
334 |
stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
335 |
app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
|