update to onnx
Files changed:
- .gitattributes +1 -0
- app.py +86 -49
- example/Bach--Fugue-in-D-Minor.mid +3 -0
- example/Beethoven--Symphony-No5-in-C-Minor-Fate-Opus-67.mid +3 -0
- example/Chopin--Nocturne No. 9 in B Major, Opus 32 No.1, Andante Sostenuto.mid +3 -0
- example/Mozart--Requiem, No.1..mid +3 -0
- example/castle_in_the_sky.mid +3 -0
- example/eva-残酷な天使のテーゼ.mid +3 -0
- midi_model.py +0 -123
- requirements.txt +1 -2
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.mid filter=lfs diff=lfs merge=lfs -text
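Note: the added rule routes MIDI files through Git LFS — it is the line that `git lfs track "*.mid"` appends to .gitattributes — so the example .mid files added below are stored as LFS pointers rather than raw bytes.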
app.py
CHANGED
@@ -1,48 +1,72 @@
 import argparse
 import glob
-…
-import …
+import os
+import os.path
+from sys import exit
+import shutil
 import gradio as gr
 import numpy as np
-import …
-…
-import …
+import onnxruntime as rt
+import PIL
+import PIL.ImageColor
 import tqdm
+from huggingface_hub import hf_hub_download
 
 import MIDI
-from midi_model import MIDIModel
-from midi_tokenizer import MIDITokenizer
 from midi_synthesizer import synthesis
-from huggingface_hub import hf_hub_download
+from midi_tokenizer import MIDITokenizer
+
+
+def softmax(x, axis):
+    x_max = np.amax(x, axis=axis, keepdims=True)
+    exp_x_shifted = np.exp(x - x_max)
+    return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
+
+
+def sample_top_p_k(probs, p, k):
+    probs_idx = np.argsort(-probs, axis=-1)
+    probs_sort = np.take_along_axis(probs, probs_idx, -1)
+    probs_sum = np.cumsum(probs_sort, axis=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort[mask] = 0.0
+    mask = np.zeros(probs_sort.shape[-1])
+    mask[:k] = 1
+    probs_sort = probs_sort * mask
+    probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
+    shape = probs_sort.shape
+    probs_sort_flat = probs_sort.reshape(-1, shape[-1])
+    probs_idx_flat = probs_idx.reshape(-1, shape[-1])
+    next_token = np.stack([np.random.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
+    next_token = next_token.reshape(*shape[:-1])
+    return next_token
+
 
-@torch.inference_mode()
 def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
-             disable_patch_change=False, disable_control_change=False, disable_channels=None, amp=True):
+             disable_patch_change=False, disable_control_change=False, disable_channels=None):
     if disable_channels is not None:
         disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
     else:
         disable_channels = []
     max_token_seq = tokenizer.max_token_seq
     if prompt is None:
-        input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=model.device)
+        input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
         input_tensor[0, 0] = tokenizer.bos_id  # bos
     else:
         prompt = prompt[:, :max_token_seq]
         if prompt.shape[-1] < max_token_seq:
             prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
                             mode="constant", constant_values=tokenizer.pad_id)
-        input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
-    input_tensor = input_tensor.unsqueeze(0)
+        input_tensor = prompt
+    input_tensor = input_tensor[None, :, :]
     cur_len = input_tensor.shape[1]
     bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
-    with bar, torch.cuda.amp.autocast(enabled=amp):
+    with bar:
         while cur_len < max_len:
             end = False
-            hidden = model.forward(input_tensor)[0, -1].unsqueeze(0)
-            next_token_seq = None
+            hidden = model_base.run(None, {'x': input_tensor})[0][:, -1]
+            next_token_seq = np.empty((1, 0), dtype=np.int64)
             event_name = ""
             for i in range(max_token_seq):
-                mask = torch.zeros(tokenizer.vocab_size, dtype=torch.int64, device=model.device)
+                mask = np.zeros(tokenizer.vocab_size, dtype=np.int64)
                 if i == 0:
                     mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
                     if disable_patch_change:
@@ -56,9 +80,9 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
                     if param_name == "channel":
                         mask_ids = [i for i in mask_ids if i not in disable_channels]
                 mask[mask_ids] = 1
-                logits = model.forward_token(hidden, next_token_seq)[:, -1:]
-                scores = torch.softmax(logits / temp, dim=-1) * mask
-                sample = model.sample_top_p_k(scores, top_p, top_k)
+                logits = model_token.run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
+                scores = softmax(logits / temp, -1) * mask
+                sample = sample_top_p_k(scores, top_p, top_k)
                 if i == 0:
                     next_token_seq = sample
                     eid = sample.item()
@@ -67,29 +91,30 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
                         break
                     event_name = tokenizer.id_events[eid]
                 else:
-                    next_token_seq = torch.cat([next_token_seq, sample], dim=1)
+                    next_token_seq = np.concatenate([next_token_seq, sample], axis=1)
                     if len(tokenizer.events[event_name]) == i:
                         break
             if next_token_seq.shape[1] < max_token_seq:
-                next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
-                                       "constant", value=tokenizer.pad_id)
-            next_token_seq = next_token_seq.unsqueeze(1)
-            input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
+                next_token_seq = np.pad(next_token_seq, ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
+                                        mode="constant", constant_values=tokenizer.pad_id)
+            next_token_seq = next_token_seq[None, :, :]
+            input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
             cur_len += 1
             bar.update(1)
-            yield next_token_seq.reshape(-1).cpu().numpy()
+            yield next_token_seq.reshape(-1)
             if end:
                 break
 
 
-def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc, amp):
+def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
     mid_seq = []
     max_len = int(gen_events)
     img_len = 1024
     img = np.full((128 * 2, img_len, 3), 255, dtype=np.uint8)
     state = {"t1": 0, "t": 0, "cur_pos": 0}
-    …
-    …
+    colors = ['navy', 'blue', 'deepskyblue', 'teal', 'green', 'lightgreen', 'lime', 'orange',
+              'brown', 'grey', 'red', 'pink', 'aqua', 'orchid', 'bisque', 'coral']
+    colors = [PIL.ImageColor.getrgb(color) for color in colors]
 
     def draw_event(tokens):
         if tokens[0] in tokenizer.id_events:
@@ -112,7 +137,7 @@ def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, t
                 img[:, -shift:] = 255
                 state["cur_pos"] += shift
             t = t - state["cur_pos"]
-            img[p * 2:(p + 1) * 2, t: t + d] = colors[…
+            img[p * 2:(p + 1) * 2, t: t + d] = colors[c]
 
     def get_img():
         t = state["t"] - state["cur_pos"]
@@ -135,7 +160,7 @@ def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, t
             mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i, c, p]))
         mid_seq = mid
         mid = np.asarray(mid, dtype=np.int64)
-        if len(instruments) > 0…
+        if len(instruments) > 0:
             disable_patch_change = True
             disable_channels = [i for i in range(16) if i not in patches]
     elif mid is not None:
@@ -148,7 +173,7 @@ def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, t
             draw_event(token_seq)
     generator = generate(mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
                          disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
-                         disable_channels=disable_channels, amp=amp)
+                         disable_channels=disable_channels)
     for token_seq in generator:
        mid_seq.append(token_seq)
        draw_event(token_seq)
@@ -179,17 +204,16 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
     parser.add_argument("--port", type=int, default=7860, help="gradio server port")
-    parser.add_argument("--…
-    parser.add_argument("--max-gen", type=int, default=512, help="max")
-    soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
-    model_path = hf_hub_download(repo_id="skytnt/midi-model", filename="model.ckpt")
+    parser.add_argument("--max-gen", type=int, default=256, help="max")
     opt = parser.parse_args()
+    soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
+    model_base_path = hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx")
+    model_token_path = hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx")
+
     tokenizer = MIDITokenizer()
-    …
-    …
-    …
-    model.load_state_dict(state_dict, strict=False)
-    model.eval()
+    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+    model_base = rt.InferenceSession(model_base_path, providers=providers)
+    model_token = rt.InferenceSession(model_token_path, providers=providers)
 
     app = gr.Blocks()
     with app:
@@ -199,39 +223,52 @@ if __name__ == "__main__":
             "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
             "[Open In Colab]"
             "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
-            " for faster running")
+            " for faster running and longer generation"
+        )
 
         tab_select = gr.Variable(value=0)
         with gr.Tabs():
             with gr.TabItem("instrument prompt") as tab1:
                 input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
-                                                multiselect=True, max_choices=…
+                                                multiselect=True, max_choices=15, type="value")
                 input_drum_kit = gr.Dropdown(label="drum kit", choices=list(drum_kits2number.keys()), type="value",
                                              value="None")
+                example1 = gr.Examples([
+                    [[], "None"],
+                    [["Acoustic Grand"], "None"],
+                    [["Acoustic Grand", "Violin", "Viola", "Cello", "Contrabass", "Timpani"], "Orchestra"],
+                    [["Acoustic Guitar(nylon)", "Acoustic Guitar(steel)", "Electric Guitar(jazz)",
+                      "Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
+                      "Electric Bass(finger)"], "Standard"],
+                    [["Acoustic Grand", "String Ensemble 1", "Trombone", "Tuba", "Muted Trumpet", "French Horn", "Oboe",
+                      "English Horn", "Bassoon", "Clarinet"], "Orchestra"]
+
+                ], [input_instruments, input_drum_kit])
             with gr.TabItem("midi prompt") as tab2:
                 input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
                 input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
                                               step=1,
                                               value=128)
+                example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
+                                       [input_midi, input_midi_events])
 
         tab1.select(lambda: 0, None, tab_select, queue=False)
        tab2.select(lambda: 1, None, tab_select, queue=False)
        input_gen_events = gr.Slider(label="generate n midi events", minimum=1, maximum=opt.max_gen,
                                     step=1, value=opt.max_gen)
        input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
-        input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.…
+        input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
        input_top_k = gr.Slider(label="top k", minimum=1, maximum=50, step=1, value=20)
-        input_allow_cc = gr.Checkbox(label="allow …
-        input_amp = gr.Checkbox(label="enable amp", value=True)
+        input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
        run_btn = gr.Button("generate", variant="primary")
        stop_btn = gr.Button("stop")
        output_midi_seq = gr.Variable()
        output_midi_img = gr.Image(label="output image")
        output_midi = gr.File(label="output midi", file_types=[".mid"])
-        output_audio = gr.Audio(label="output audio", format="…
+        output_audio = gr.Audio(label="output audio", format="wav")
        run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi, input_midi_events,
                                        input_gen_events, input_temp, input_top_p, input_top_k,
-                                        input_allow_cc, input_amp],
+                                        input_allow_cc],
                                  [output_midi_seq, output_midi_img, output_midi, output_audio])
        stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
-    app.queue(…
+    app.queue(2).launch(server_port=opt.port, share=opt.share, inbrowser=True)
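Since sampling is now plain numpy, it can be exercised outside the Space. A quick standalone sanity check of the new top-p/top-k sampler (helpers copied from the diff above; the 128-way logits are made up for the test, standing in for model_token's output):

import numpy as np

# Copied from the new app.py: numerically stable softmax over one axis.
def softmax(x, axis):
    x_max = np.amax(x, axis=axis, keepdims=True)
    exp_x_shifted = np.exp(x - x_max)
    return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)

# Copied from the new app.py: nucleus (top-p) plus top-k sampling in numpy.
def sample_top_p_k(probs, p, k):
    probs_idx = np.argsort(-probs, axis=-1)
    probs_sort = np.take_along_axis(probs, probs_idx, -1)
    probs_sum = np.cumsum(probs_sort, axis=-1)
    mask = probs_sum - probs_sort > p    # zero out tokens beyond the nucleus p
    probs_sort[mask] = 0.0
    mask = np.zeros(probs_sort.shape[-1])
    mask[:k] = 1                         # keep only the k most likely tokens
    probs_sort = probs_sort * mask
    probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
    shape = probs_sort.shape
    probs_sort_flat = probs_sort.reshape(-1, shape[-1])
    probs_idx_flat = probs_idx.reshape(-1, shape[-1])
    next_token = np.stack([np.random.choice(idxs, p=pvals)
                           for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
    return next_token.reshape(*shape[:-1])

logits = np.random.randn(1, 1, 128)        # fake (batch, 1, vocab) logits
scores = softmax(logits / 1.0, -1)         # temperature 1.0
sample = sample_top_p_k(scores, 0.98, 20)
print(sample.shape)                        # (1, 1): one token id per sequence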
example/Bach--Fugue-in-D-Minor.mid
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1398121eb86a33e73f90ec84be71dac6abc0ddf11372ea7cdd9e01586938a56b
+size 7720

example/Beethoven--Symphony-No5-in-C-Minor-Fate-Opus-67.mid
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28ff6fdcd644e781d36411bf40ab7a1f4849adddbcd1040eaec22751c5ca99d2
+size 87090

example/Chopin--Nocturne No. 9 in B Major, Opus 32 No.1, Andante Sostenuto.mid
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a236e647ad9f5d0af680d3ca19d3b60f334c4bde6b4f86310f63405245c476e
+size 13484

example/Mozart--Requiem, No.1..mid
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa49bf4633401e16777fe47f6f53a494c2166f5101af6dafc60114932a59b9bd
+size 14695

example/castle_in_the_sky.mid
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa14aec6f1be15c4fddd0decc6d9152204f160d4e07e05d8d1dc9f209c309ff7
+size 7957

example/eva-残酷な天使のテーゼ.mid
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e513487543d7e27ec5dc30f027302d2a3b5a3aaf9af554def1e5cd6a7a8d355a
+size 17671
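Each example file is checked in as a Git LFS pointer (spec version, sha256 oid, byte size) rather than raw MIDI bytes, per the .gitattributes rule added above; app.py picks them up at runtime via glob.glob("example/*.mid").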
midi_model.py
DELETED
@@ -1,123 +0,0 @@
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import tqdm
-from transformers import LlamaModel, LlamaConfig
-from transformers.modeling_utils import ModuleUtilsMixin
-
-from midi_tokenizer import MIDITokenizer
-
-
-class MIDIModel(nn.Module, ModuleUtilsMixin):
-    def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096, flash=False,
-                 *args, **kwargs):
-        super(MIDIModel, self).__init__()
-        self.tokenizer = tokenizer
-        self.net = LlamaModel(LlamaConfig(vocab_size=tokenizer.vocab_size,
-                                          hidden_size=n_embd, num_attention_heads=n_head,
-                                          num_hidden_layers=n_layer, intermediate_size=n_inner,
-                                          pad_token_id=tokenizer.pad_id, max_position_embeddings=4096))
-        self.net_token = LlamaModel(LlamaConfig(vocab_size=tokenizer.vocab_size,
-                                                hidden_size=n_embd, num_attention_heads=n_head // 4,
-                                                num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
-                                                pad_token_id=tokenizer.pad_id, max_position_embeddings=4096))
-        if flash:
-            self.net = self.net.to_bettertransformer()
-            self.net_token = self.net_token.to_bettertransformer()
-        self.lm_head = nn.Linear(n_embd, tokenizer.vocab_size, bias=False)
-
-    def forward_token(self, hidden_state, x=None):
-        """
-
-        :param hidden_state: (batch_size, n_embd)
-        :param x: (batch_size, token_sequence_length)
-        :return: (batch_size, 1 + token_sequence_length, vocab_size)
-        """
-        hidden_state = hidden_state.unsqueeze(1)  # (batch_size, 1, n_embd)
-        if x is not None:
-            x = self.net_token.embed_tokens(x)
-            hidden_state = torch.cat([hidden_state, x], dim=1)
-        hidden_state = self.net_token.forward(inputs_embeds=hidden_state).last_hidden_state
-        return self.lm_head(hidden_state)
-
-    def forward(self, x):
-        """
-        :param x: (batch_size, time_sequence_length, token_sequence_length)
-        :return: hidden (batch_size, time_sequence_length, n_embd)
-        """
-
-        # merge token sequence
-        x = self.net.embed_tokens(x)
-        x = x.sum(dim=-2)
-        x = self.net.forward(inputs_embeds=x)
-        return x.last_hidden_state
-
-    def sample_top_p_k(self, probs, p, k):
-        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-        probs_sum = torch.cumsum(probs_sort, dim=-1)
-        mask = probs_sum - probs_sort > p
-        probs_sort[mask] = 0.0
-        mask = torch.zeros(probs_sort.shape[-1], device=probs_sort.device)
-        mask[:k] = 1
-        probs_sort = probs_sort * mask
-        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-        shape = probs_sort.shape
-        next_token = torch.multinomial(probs_sort.reshape(-1, shape[-1]), num_samples=1).reshape(*shape[:-1], 1)
-        next_token = torch.gather(probs_idx, -1, next_token).reshape(*shape[:-1])
-        return next_token
-
-    @torch.inference_mode()
-    def generate(self, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20, amp=True):
-        tokenizer = self.tokenizer
-        max_token_seq = tokenizer.max_token_seq
-        if prompt is None:
-            input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=self.device)
-            input_tensor[0, 0] = tokenizer.bos_id  # bos
-        else:
-            prompt = prompt[:, :max_token_seq]
-            if prompt.shape[-1] < max_token_seq:
-                prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
-                                mode="constant", constant_values=tokenizer.pad_id)
-            input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=self.device)
-        input_tensor = input_tensor.unsqueeze(0)
-        cur_len = input_tensor.shape[1]
-        bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
-        with bar, torch.cuda.amp.autocast(enabled=amp):
-            while cur_len < max_len:
-                end = False
-                hidden = self.forward(input_tensor)[0, -1].unsqueeze(0)
-                next_token_seq = None
-                event_name = ""
-                for i in range(max_token_seq):
-                    mask = torch.zeros(tokenizer.vocab_size, dtype=torch.int64, device=self.device)
-                    if i == 0:
-                        mask[list(tokenizer.event_ids.values()) + [tokenizer.eos_id]] = 1
-                    else:
-                        param_name = tokenizer.events[event_name][i - 1]
-                        mask[tokenizer.parameter_ids[param_name]] = 1
-
-                    logits = self.forward_token(hidden, next_token_seq)[:, -1:]
-                    scores = torch.softmax(logits / temp, dim=-1) * mask
-                    sample = self.sample_top_p_k(scores, top_p, top_k)
-                    if i == 0:
-                        next_token_seq = sample
-                        eid = sample.item()
-                        if eid == tokenizer.eos_id:
-                            end = True
-                            break
-                        event_name = tokenizer.id_events[eid]
-                    else:
-                        next_token_seq = torch.cat([next_token_seq, sample], dim=1)
-                        if len(tokenizer.events[event_name]) == i:
-                            break
-                if next_token_seq.shape[1] < max_token_seq:
-                    next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
-                                           "constant", value=tokenizer.pad_id)
-                next_token_seq = next_token_seq.unsqueeze(1)
-                input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
-                cur_len += 1
-                bar.update(1)
-                if end:
-                    break
-        return input_tensor[0].cpu().numpy()
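midi_model.py goes away because inference no longer runs the torch graph; the Space now downloads pre-exported onnx/model_base.onnx and onnx/model_token.onnx instead. For reference, a hypothetical export sketch along these lines (not part of this commit; checkpoint path, dummy shapes, and axis names are assumptions) would produce a graph whose 'x' input matches what the new app.py feeds to InferenceSession.run:

# Hypothetical sketch of how model_base.onnx could have been exported from the
# deleted MIDIModel above -- NOT part of this commit.
import torch
from midi_model import MIDIModel          # the module deleted in this commit
from midi_tokenizer import MIDITokenizer

tokenizer = MIDITokenizer()
model = MIDIModel(tokenizer).eval()
model.load_state_dict(torch.load("model.ckpt", map_location="cpu"), strict=False)

# MIDIModel.forward takes (batch, time, token_seq) int64 ids, returns hidden states.
dummy_x = torch.randint(0, tokenizer.vocab_size, (1, 16, tokenizer.max_token_seq))
torch.onnx.export(model, (dummy_x,), "model_base.onnx",
                  input_names=["x"], output_names=["hidden"],
                  dynamic_axes={"x": {0: "batch", 1: "time"},
                                "hidden": {0: "batch", 1: "time"}})
# model_token.onnx would be exported analogously from forward_token, with
# inputs named "hidden" and "x" as consumed by model_token.run in app.py.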
requirements.txt
CHANGED
@@ -1,6 +1,5 @@
 Pillow
 numpy
-torch
-transformers
+onnxruntime-gpu
 gradio
 pyfluidsynth
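requirements.txt swaps torch and transformers for onnxruntime-gpu. The new app.py requests CUDA first and falls back to CPU via its providers list; a quick way to see what a given install actually offers:

# List the execution providers available in this onnxruntime build; app.py
# passes ['CUDAExecutionProvider', 'CPUExecutionProvider'] in preference order.
import onnxruntime as rt
print(rt.get_available_providers())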