skytnt commited on
Commit
9d7618f
·
1 Parent(s): 43a6dd3

torch seems slow, bring back onnx

Browse files
Files changed (2) hide show
  1. app.py +61 -45
  2. requirements.txt +1 -3
app.py CHANGED
@@ -1,54 +1,79 @@
1
  import argparse
2
  import glob
3
- import json
4
- import os
5
  import time
 
6
 
7
  import gradio as gr
8
  import numpy as np
9
- import torch
10
-
11
- import torch.nn.functional as F
12
  import tqdm
 
 
13
 
14
  import MIDI
15
- from midi_model import MIDIModel
16
- from midi_tokenizer import MIDITokenizer
17
  from midi_synthesizer import synthesis
18
- from huggingface_hub import hf_hub_download
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  in_space = os.getenv("SYSTEM") == "spaces"
22
 
23
 
24
- @torch.inference_mode()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
26
- disable_patch_change=False, disable_control_change=False, disable_channels=None, amp=True, generator=None):
27
  if disable_channels is not None:
28
  disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
29
  else:
30
  disable_channels = []
 
 
31
  max_token_seq = tokenizer.max_token_seq
32
  if prompt is None:
33
- input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=model.device)
34
  input_tensor[0, 0] = tokenizer.bos_id # bos
35
  else:
36
  prompt = prompt[:, :max_token_seq]
37
  if prompt.shape[-1] < max_token_seq:
38
  prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
39
  mode="constant", constant_values=tokenizer.pad_id)
40
- input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
41
- input_tensor = input_tensor.unsqueeze(0)
42
  cur_len = input_tensor.shape[1]
43
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
44
- with bar, torch.amp.autocast(device_type=model.device, enabled=amp):
45
  while cur_len < max_len:
46
  end = False
47
- hidden = model.forward(input_tensor)[0, -1].unsqueeze(0)
48
- next_token_seq = None
49
  event_name = ""
50
  for i in range(max_token_seq):
51
- mask = torch.zeros(tokenizer.vocab_size, dtype=torch.int64, device=model.device)
52
  if i == 0:
53
  mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
54
  if disable_patch_change:
@@ -62,9 +87,9 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
62
  if param_name == "channel":
63
  mask_ids = [i for i in mask_ids if i not in disable_channels]
64
  mask[mask_ids] = 1
65
- logits = model.forward_token(hidden, next_token_seq)[:, -1:]
66
- scores = torch.softmax(logits / temp, dim=-1) * mask
67
- sample = model.sample_top_p_k(scores, top_p, top_k, generator=generator)
68
  if i == 0:
69
  next_token_seq = sample
70
  eid = sample.item()
@@ -73,17 +98,17 @@ def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
73
  break
74
  event_name = tokenizer.id_events[eid]
75
  else:
76
- next_token_seq = torch.cat([next_token_seq, sample], dim=1)
77
  if len(tokenizer.events[event_name]) == i:
78
  break
79
  if next_token_seq.shape[1] < max_token_seq:
80
- next_token_seq = F.pad(next_token_seq, (0, max_token_seq - next_token_seq.shape[1]),
81
- "constant", value=tokenizer.pad_id)
82
- next_token_seq = next_token_seq.unsqueeze(1)
83
- input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
84
  cur_len += 1
85
  bar.update(1)
86
- yield next_token_seq.reshape(-1).cpu().numpy()
87
  if end:
88
  break
89
 
@@ -104,7 +129,7 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, midi_opt,
104
  max_len = gen_events
105
  if seed_rand:
106
  seed = np.random.randint(0, MAX_SEED)
107
- generator = torch.Generator(device).manual_seed(seed)
108
  disable_patch_change = False
109
  disable_channels = None
110
  if tab == 0:
@@ -135,16 +160,14 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, midi_opt,
135
  for token_seq in mid:
136
  mid_seq.append(token_seq.tolist())
137
  max_len += len(mid)
138
-
139
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
140
  init_msgs = [create_msg("visualizer_clear", None), create_msg("visualizer_append", events)]
141
  t = time.time() + 1
142
  yield mid_seq, None, None, seed, send_msgs(init_msgs)
143
  model = models[model_name]
144
- amp = device == "cuda"
145
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
146
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
147
- disable_channels=disable_channels, amp=amp, generator=generator)
148
  events = []
149
  for i, token_seq in enumerate(midi_generator):
150
  token_seq = token_seq.tolist()
@@ -222,21 +245,15 @@ if __name__ == "__main__":
222
  "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
223
  "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
224
  }
225
- device = "cuda" if torch.cuda.is_available() else "cpu"
226
- if device=="cuda": # flash attn
227
- torch.backends.cuda.enable_mem_efficient_sdp(True)
228
- torch.backends.cuda.enable_flash_sdp(True)
229
  models = {}
230
  tokenizer = MIDITokenizer()
 
231
  for name, (repo_id, path) in models_info.items():
232
-
233
- model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
234
- model = MIDIModel(tokenizer).to(device=device)
235
- ckpt = torch.load(model_path, weights_only=True)
236
- state_dict = ckpt.get("state_dict", ckpt)
237
- model.load_state_dict(state_dict, strict=False)
238
- model.eval()
239
- models[name] = model
240
 
241
  load_javascript()
242
  app = gr.Blocks()
@@ -248,8 +265,7 @@ if __name__ == "__main__":
248
  "[Open In Colab]"
249
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
250
  " for faster running and longer generation\n\n"
251
- "**Update v1.2**: Optimise the tokenizer and dataset\n\n"
252
- f"Device: {device}"
253
  )
254
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
255
  js_msg.change(None, [js_msg], [], js="""
@@ -319,4 +335,4 @@ if __name__ == "__main__":
319
  [output_midi_seq, output_midi, output_audio, input_seed, js_msg],
320
  concurrency_limit=3)
321
  stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
322
- app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
1
  import argparse
2
  import glob
3
+ import os.path
 
4
  import time
5
+ import uuid
6
 
7
  import gradio as gr
8
  import numpy as np
9
+ import onnxruntime as rt
 
 
10
  import tqdm
11
+ import json
12
+ from huggingface_hub import hf_hub_download
13
 
14
  import MIDI
 
 
15
  from midi_synthesizer import synthesis
16
+ from midi_tokenizer import MIDITokenizer
17
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
  in_space = os.getenv("SYSTEM") == "spaces"
20
 
21
 
22
+ def softmax(x, axis):
23
+ x_max = np.amax(x, axis=axis, keepdims=True)
24
+ exp_x_shifted = np.exp(x - x_max)
25
+ return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
26
+
27
+
28
+ def sample_top_p_k(probs, p, k, generator=None):
29
+ if generator is None:
30
+ generator = np.random
31
+ probs_idx = np.argsort(-probs, axis=-1)
32
+ probs_sort = np.take_along_axis(probs, probs_idx, -1)
33
+ probs_sum = np.cumsum(probs_sort, axis=-1)
34
+ mask = probs_sum - probs_sort > p
35
+ probs_sort[mask] = 0.0
36
+ mask = np.zeros(probs_sort.shape[-1])
37
+ mask[:k] = 1
38
+ probs_sort = probs_sort * mask
39
+ probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
40
+ shape = probs_sort.shape
41
+ probs_sort_flat = probs_sort.reshape(-1, shape[-1])
42
+ probs_idx_flat = probs_idx.reshape(-1, shape[-1])
43
+ next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
44
+ next_token = next_token.reshape(*shape[:-1])
45
+ return next_token
46
+
47
+
48
  def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
49
+ disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
50
  if disable_channels is not None:
51
  disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
52
  else:
53
  disable_channels = []
54
+ if generator is None:
55
+ generator = np.random
56
  max_token_seq = tokenizer.max_token_seq
57
  if prompt is None:
58
+ input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
59
  input_tensor[0, 0] = tokenizer.bos_id # bos
60
  else:
61
  prompt = prompt[:, :max_token_seq]
62
  if prompt.shape[-1] < max_token_seq:
63
  prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
64
  mode="constant", constant_values=tokenizer.pad_id)
65
+ input_tensor = prompt
66
+ input_tensor = input_tensor[None, :, :]
67
  cur_len = input_tensor.shape[1]
68
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
69
+ with bar:
70
  while cur_len < max_len:
71
  end = False
72
+ hidden = model[0].run(None, {'x': input_tensor})[0][:, -1]
73
+ next_token_seq = np.empty((1, 0), dtype=np.int64)
74
  event_name = ""
75
  for i in range(max_token_seq):
76
+ mask = np.zeros(tokenizer.vocab_size, dtype=np.int64)
77
  if i == 0:
78
  mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
79
  if disable_patch_change:
 
87
  if param_name == "channel":
88
  mask_ids = [i for i in mask_ids if i not in disable_channels]
89
  mask[mask_ids] = 1
90
+ logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
91
+ scores = softmax(logits / temp, -1) * mask
92
+ sample = sample_top_p_k(scores, top_p, top_k, generator)
93
  if i == 0:
94
  next_token_seq = sample
95
  eid = sample.item()
 
98
  break
99
  event_name = tokenizer.id_events[eid]
100
  else:
101
+ next_token_seq = np.concatenate([next_token_seq, sample], axis=1)
102
  if len(tokenizer.events[event_name]) == i:
103
  break
104
  if next_token_seq.shape[1] < max_token_seq:
105
+ next_token_seq = np.pad(next_token_seq, ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
106
+ mode="constant", constant_values=tokenizer.pad_id)
107
+ next_token_seq = next_token_seq[None, :, :]
108
+ input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
109
  cur_len += 1
110
  bar.update(1)
111
+ yield next_token_seq.reshape(-1)
112
  if end:
113
  break
114
 
 
129
  max_len = gen_events
130
  if seed_rand:
131
  seed = np.random.randint(0, MAX_SEED)
132
+ generator = np.random.RandomState(seed)
133
  disable_patch_change = False
134
  disable_channels = None
135
  if tab == 0:
 
160
  for token_seq in mid:
161
  mid_seq.append(token_seq.tolist())
162
  max_len += len(mid)
 
163
  events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
164
  init_msgs = [create_msg("visualizer_clear", None), create_msg("visualizer_append", events)]
165
  t = time.time() + 1
166
  yield mid_seq, None, None, seed, send_msgs(init_msgs)
167
  model = models[model_name]
 
168
  midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
169
  disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
170
+ disable_channels=disable_channels, generator=generator)
171
  events = []
172
  for i, token_seq in enumerate(midi_generator):
173
  token_seq = token_seq.tolist()
 
245
  "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
246
  "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
247
  }
 
 
 
 
248
  models = {}
249
  tokenizer = MIDITokenizer()
250
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
251
  for name, (repo_id, path) in models_info.items():
252
+ model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
253
+ model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
254
+ model_base = rt.InferenceSession(model_base_path, providers=providers)
255
+ model_token = rt.InferenceSession(model_token_path, providers=providers)
256
+ models[name] = [model_base, model_token]
 
 
 
257
 
258
  load_javascript()
259
  app = gr.Blocks()
 
265
  "[Open In Colab]"
266
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
267
  " for faster running and longer generation\n\n"
268
+ "**Update v1.2**: Optimise the tokenizer and dataset"
 
269
  )
270
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
271
  js_msg.change(None, [js_msg], [], js="""
 
335
  [output_midi_seq, output_midi, output_audio, input_seed, js_msg],
336
  concurrency_limit=3)
337
  stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
338
+ app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
requirements.txt CHANGED
@@ -1,8 +1,6 @@
1
- --extra-index-url https://download.pytorch.org/whl/cu124
2
  Pillow
3
  numpy
4
- torch
5
- transformers>=4.36
6
  gradio==4.43.0
7
  pyfluidsynth
8
  tqdm
 
 
1
  Pillow
2
  numpy
3
+ onnxruntime-gpu
 
4
  gradio==4.43.0
5
  pyfluidsynth
6
  tqdm