Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -41,7 +41,7 @@ def sample_top_p_k(probs, p, k):
|
|
41 |
return next_token
|
42 |
|
43 |
|
44 |
-
def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
45 |
disable_patch_change=False, disable_control_change=False, disable_channels=None):
|
46 |
if disable_channels is not None:
|
47 |
disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
|
@@ -63,7 +63,7 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
63 |
with bar:
|
64 |
while cur_len < max_len:
|
65 |
end = False
|
66 |
-
hidden =
|
67 |
next_token_seq = np.empty((1, 0), dtype=np.int64)
|
68 |
event_name = ""
|
69 |
for i in range(max_token_seq):
|
@@ -81,7 +81,7 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
81 |
if param_name == "channel":
|
82 |
mask_ids = [i for i in mask_ids if i not in disable_channels]
|
83 |
mask[mask_ids] = 1
|
84 |
-
logits =
|
85 |
scores = softmax(logits / temp, -1) * mask
|
86 |
sample = sample_top_p_k(scores, top_p, top_k)
|
87 |
if i == 0:
|
@@ -107,7 +107,7 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
107 |
break
|
108 |
|
109 |
|
110 |
-
def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
111 |
mid_seq = []
|
112 |
max_len = int(gen_events)
|
113 |
img_len = 1024
|
@@ -172,7 +172,8 @@ def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, t
|
|
172 |
for token_seq in mid:
|
173 |
mid_seq.append(token_seq)
|
174 |
draw_event(token_seq)
|
175 |
-
|
|
|
176 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
177 |
disable_channels=disable_channels)
|
178 |
for token_seq in generator:
|
@@ -208,13 +209,18 @@ if __name__ == "__main__":
|
|
208 |
parser.add_argument("--max-gen", type=int, default=1024, help="max")
|
209 |
opt = parser.parse_args()
|
210 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
211 |
-
|
212 |
-
|
213 |
-
|
|
|
214 |
tokenizer = MIDITokenizer()
|
215 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
216 |
-
|
217 |
-
|
|
|
|
|
|
|
|
|
218 |
|
219 |
app = gr.Blocks()
|
220 |
with app:
|
@@ -229,6 +235,8 @@ if __name__ == "__main__":
|
|
229 |
|
230 |
tab_select = gr.Variable(value=0)
|
231 |
with gr.Tabs():
|
|
|
|
|
232 |
with gr.TabItem("instrument prompt") as tab1:
|
233 |
input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
|
234 |
multiselect=True, max_choices=15, type="value")
|
@@ -260,7 +268,7 @@ if __name__ == "__main__":
|
|
260 |
with gr.Accordion("options", open=False):
|
261 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
262 |
input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
|
263 |
-
input_top_k = gr.Slider(label="top k", minimum=1, maximum=20, step=1, value=
|
264 |
input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
|
265 |
example3 = gr.Examples([[1, 0.98, 12], [1.2, 0.95, 8]], [input_temp, input_top_p, input_top_k])
|
266 |
run_btn = gr.Button("generate", variant="primary")
|
@@ -269,8 +277,8 @@ if __name__ == "__main__":
|
|
269 |
output_midi_img = gr.Image(label="output image")
|
270 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
271 |
output_audio = gr.Audio(label="output audio", format="mp3")
|
272 |
-
run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi,
|
273 |
-
input_gen_events, input_temp, input_top_p, input_top_k,
|
274 |
input_allow_cc],
|
275 |
[output_midi_seq, output_midi_img, output_midi, output_audio])
|
276 |
stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
|
|
|
41 |
return next_token
|
42 |
|
43 |
|
44 |
+
def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
45 |
disable_patch_change=False, disable_control_change=False, disable_channels=None):
|
46 |
if disable_channels is not None:
|
47 |
disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
|
|
|
63 |
with bar:
|
64 |
while cur_len < max_len:
|
65 |
end = False
|
66 |
+
hidden = model[0].run(None, {'x': input_tensor})[0][:, -1]
|
67 |
next_token_seq = np.empty((1, 0), dtype=np.int64)
|
68 |
event_name = ""
|
69 |
for i in range(max_token_seq):
|
|
|
81 |
if param_name == "channel":
|
82 |
mask_ids = [i for i in mask_ids if i not in disable_channels]
|
83 |
mask[mask_ids] = 1
|
84 |
+
logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
|
85 |
scores = softmax(logits / temp, -1) * mask
|
86 |
sample = sample_top_p_k(scores, top_p, top_k)
|
87 |
if i == 0:
|
|
|
107 |
break
|
108 |
|
109 |
|
110 |
+
def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
111 |
mid_seq = []
|
112 |
max_len = int(gen_events)
|
113 |
img_len = 1024
|
|
|
172 |
for token_seq in mid:
|
173 |
mid_seq.append(token_seq)
|
174 |
draw_event(token_seq)
|
175 |
+
model = models[model_name]
|
176 |
+
generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
177 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
178 |
disable_channels=disable_channels)
|
179 |
for token_seq in generator:
|
|
|
209 |
parser.add_argument("--max-gen", type=int, default=1024, help="max")
|
210 |
opt = parser.parse_args()
|
211 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
212 |
+
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
213 |
+
"symphony finetune model": ["skytnt/midi-model-ft", "symphony/"],
|
214 |
+
"touhou finetune model": ["skytnt/midi-model-ft", "touhou/"]}
|
215 |
+
models = {}
|
216 |
tokenizer = MIDITokenizer()
|
217 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
218 |
+
for name, (repo_id, path) in models_info.items():
|
219 |
+
model_base_path = hf_hub_download(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
|
220 |
+
model_token_path = hf_hub_download(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
|
221 |
+
model_base = rt.InferenceSession(model_base_path, providers=providers)
|
222 |
+
model_token = rt.InferenceSession(model_token_path, providers=providers)
|
223 |
+
models[name] = [model_base, model_token]
|
224 |
|
225 |
app = gr.Blocks()
|
226 |
with app:
|
|
|
235 |
|
236 |
tab_select = gr.Variable(value=0)
|
237 |
with gr.Tabs():
|
238 |
+
input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
|
239 |
+
type="value", value=list(models.keys())[0])
|
240 |
with gr.TabItem("instrument prompt") as tab1:
|
241 |
input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
|
242 |
multiselect=True, max_choices=15, type="value")
|
|
|
268 |
with gr.Accordion("options", open=False):
|
269 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
270 |
input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
|
271 |
+
input_top_k = gr.Slider(label="top k", minimum=1, maximum=20, step=1, value=20)
|
272 |
input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
|
273 |
example3 = gr.Examples([[1, 0.98, 12], [1.2, 0.95, 8]], [input_temp, input_top_p, input_top_k])
|
274 |
run_btn = gr.Button("generate", variant="primary")
|
|
|
277 |
output_midi_img = gr.Image(label="output image")
|
278 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
279 |
output_audio = gr.Audio(label="output audio", format="mp3")
|
280 |
+
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
|
281 |
+
input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
|
282 |
input_allow_cc],
|
283 |
[output_midi_seq, output_midi_img, output_midi, output_audio])
|
284 |
stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
|