Spaces:
Sleeping
Sleeping
changes
Browse files- app.py +3 -4
- midi_model.py +1 -1
app.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
from concurrent.futures import ThreadPoolExecutor
|
2 |
-
|
3 |
import spaces
|
4 |
import random
|
5 |
import argparse
|
@@ -7,6 +5,7 @@ import glob
|
|
7 |
import json
|
8 |
import os
|
9 |
import time
|
|
|
10 |
|
11 |
import gradio as gr
|
12 |
import numpy as np
|
@@ -122,7 +121,7 @@ def send_msgs(msgs):
|
|
122 |
def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
|
123 |
time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
|
124 |
remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
125 |
-
t = 1e-4*gen_events**2 +
|
126 |
if "large" in model_name:
|
127 |
t *= 2
|
128 |
return t
|
@@ -383,7 +382,7 @@ if __name__ == "__main__":
|
|
383 |
with app:
|
384 |
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
|
385 |
gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=skytnt.midi-composer&style=flat)\n\n"
|
386 |
-
"Midi event transformer for music generation\n\n"
|
387 |
"Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
|
388 |
"[Open In Colab]"
|
389 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
|
|
|
|
|
|
1 |
import spaces
|
2 |
import random
|
3 |
import argparse
|
|
|
5 |
import json
|
6 |
import os
|
7 |
import time
|
8 |
+
from concurrent.futures import ThreadPoolExecutor
|
9 |
|
10 |
import gradio as gr
|
11 |
import numpy as np
|
|
|
121 |
def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
|
122 |
time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
|
123 |
remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
124 |
+
t = 1e-4*gen_events**2 + 25
|
125 |
if "large" in model_name:
|
126 |
t *= 2
|
127 |
return t
|
|
|
382 |
with app:
|
383 |
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
|
384 |
gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=skytnt.midi-composer&style=flat)\n\n"
|
385 |
+
"Midi event transformer for symbolic music generation\n\n"
|
386 |
"Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
|
387 |
"[Open In Colab]"
|
388 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
midi_model.py
CHANGED
@@ -125,7 +125,7 @@ class MIDIModel(nn.Module):
|
|
125 |
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
126 |
elif prompt.shape[0] == 1:
|
127 |
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
128 |
-
|
129 |
raise ValueError(f"invalid shape for prompt, {prompt.shape}")
|
130 |
prompt = prompt[..., :max_token_seq]
|
131 |
if prompt.shape[-1] < max_token_seq:
|
|
|
125 |
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
126 |
elif prompt.shape[0] == 1:
|
127 |
prompt = np.repeat(prompt, repeats=batch_size, axis=0)
|
128 |
+
elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
|
129 |
raise ValueError(f"invalid shape for prompt, {prompt.shape}")
|
130 |
prompt = prompt[..., :max_token_seq]
|
131 |
if prompt.shape[-1] < max_token_seq:
|