skytnt commited on
Commit
2e60fd4
·
1 Parent(s): 5bef524

add kv cache for onnx

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +2 -1
  3. app_onnx.py +77 -18
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: red
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 4.43.0
8
- app_file: app.py
9
  pinned: true
10
  license: apache-2.0
11
  ---
 
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 4.43.0
8
+ app_file: app_onnx.py
9
  pinned: true
10
  license: apache-2.0
11
  ---
app.py CHANGED
@@ -415,7 +415,8 @@ if __name__ == "__main__":
415
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
416
  " or [download windows app](https://github.com/SkyTNT/midi-model/releases)"
417
  " for unlimited generation\n\n"
418
- "**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer"
 
419
  )
420
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
421
  js_msg.change(None, [js_msg], [], js="""
 
415
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
416
  " or [download windows app](https://github.com/SkyTNT/midi-model/releases)"
417
  " for unlimited generation\n\n"
418
+ "**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer\n\n"
419
+ "The current **best** model: generic pretrain model (tv2o-medium) by skytnt"
420
  )
421
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
422
  js_msg.change(None, [js_msg], [], js="""
app_onnx.py CHANGED
@@ -47,6 +47,37 @@ def sample_top_p_k(probs, p, k, generator=None):
47
  return next_token
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
51
  disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
52
  tokenizer = model[2]
@@ -77,12 +108,31 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
77
  input_tensor = prompt
78
  cur_len = input_tensor.shape[1]
79
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
 
 
 
 
 
 
 
80
  with bar:
81
  while cur_len < max_len:
82
  end = [False] * batch_size
83
- hidden = model[0].run(None, {'x': input_tensor})[0][:, -1]
84
- next_token_seq = np.empty((batch_size, 0), dtype=np.int64)
 
 
 
 
 
 
 
 
 
 
85
  event_names = [""] * batch_size
 
 
86
  for i in range(max_token_seq):
87
  mask = np.zeros((batch_size, tokenizer.vocab_size), dtype=np.int64)
88
  for b in range(batch_size):
@@ -107,7 +157,24 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
107
  mask_ids = [i for i in mask_ids if i not in disable_channels]
108
  mask[b, mask_ids] = 1
109
  mask = mask[:, None, :]
110
- logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  scores = softmax(logits / temp, -1) * mask
112
  samples = sample_top_p_k(scores, top_p, top_k, generator)
113
  if i == 0:
@@ -130,6 +197,7 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
130
  mode="constant", constant_values=tokenizer.pad_id)
131
  next_token_seq = next_token_seq[:, None, :]
132
  input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
 
133
  cur_len += 1
134
  bar.update(1)
135
  yield next_token_seq[:, 0]
@@ -145,24 +213,13 @@ def send_msgs(msgs):
145
  return json.dumps(msgs)
146
 
147
 
148
- def calc_time(x):
149
- return 5.849e-5*x**2 + 0.04781*x + 0.1168
150
-
151
  def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
152
  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
153
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
154
- if tab == 0:
155
- start_events = 1
156
- elif tab == 1 and mid is not None:
157
- start_events = midi_events
158
- elif tab == 2 and mid_seq is not None:
159
- start_events = len(mid_seq[0])
160
- else:
161
- start_events = 1
162
- t = calc_time(start_events + gen_events) - calc_time(start_events) + 5
163
  if "large" in model_name:
164
- t *= 2
165
- return t
166
 
167
 
168
  @spaces.GPU(duration=get_duration)
@@ -428,6 +485,7 @@ if __name__ == "__main__":
428
  }
429
  models = {}
430
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
 
431
 
432
  for name, (repo_id, path, config, loras) in models_info.items():
433
  model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
@@ -451,7 +509,8 @@ if __name__ == "__main__":
451
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
452
  " or [download windows app](https://github.com/SkyTNT/midi-model/releases)"
453
  " for unlimited generation\n\n"
454
- "**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer"
 
455
  )
456
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
457
  js_msg.change(None, [js_msg], [], js="""
 
47
  return next_token
48
 
49
 
50
+ def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, past_len, cur_len):
51
+ io_binding = model.io_binding()
52
+ for input_ in model.get_inputs():
53
+ name = input_.name
54
+ if name.startswith("past_key_values"):
55
+ present_name = name.replace("past_key_values", "present")
56
+ if present_name in outputs:
57
+ v = outputs[present_name]
58
+ else:
59
+ v = rt.OrtValue.ortvalue_from_shape_and_type(
60
+ (batch_size, input_.shape[1], past_len, input_.shape[3]),
61
+ element_type=np.float32,
62
+ device_type=device)
63
+ inputs[name] = v
64
+ else:
65
+ v = inputs[name]
66
+ io_binding.bind_ortvalue_input(name, v)
67
+
68
+ for output in model.get_outputs():
69
+ name = output.name
70
+ if name.startswith("present"):
71
+ v = rt.OrtValue.ortvalue_from_shape_and_type(
72
+ (batch_size, output.shape[1], cur_len, output.shape[3]),
73
+ element_type=np.float32,
74
+ device_type=device)
75
+ outputs[name] = v
76
+ else:
77
+ v = outputs[name]
78
+ io_binding.bind_ortvalue_output(name, v)
79
+ return io_binding
80
+
81
  def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
82
  disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
83
  tokenizer = model[2]
 
108
  input_tensor = prompt
109
  cur_len = input_tensor.shape[1]
110
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
111
+ model0_inputs = {}
112
+ model0_outputs = {}
113
+ emb_size = 1024
114
+ for output in model[0].get_outputs():
115
+ if output.name == "hidden":
116
+ emb_size = output.shape[2]
117
+ past_len = 0
118
  with bar:
119
  while cur_len < max_len:
120
  end = [False] * batch_size
121
+ model0_inputs["x"] = rt.OrtValue.ortvalue_from_numpy(input_tensor[:, past_len:], device_type=device)
122
+ model0_outputs["hidden"] = rt.OrtValue.ortvalue_from_shape_and_type(
123
+ (batch_size, cur_len - past_len, emb_size),
124
+ element_type=np.float32,
125
+ device_type=device)
126
+ io_binding = apply_io_binding(model[0], model0_inputs, model0_outputs, batch_size, past_len, cur_len)
127
+ io_binding.synchronize_inputs()
128
+ model[0].run_with_iobinding(io_binding)
129
+ io_binding.synchronize_outputs()
130
+
131
+ hidden = model0_outputs["hidden"].numpy()[:, -1:]
132
+ next_token_seq = np.zeros((batch_size, 0), dtype=np.int64)
133
  event_names = [""] * batch_size
134
+ model1_inputs = {"hidden": rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)}
135
+ model1_outputs = {}
136
  for i in range(max_token_seq):
137
  mask = np.zeros((batch_size, tokenizer.vocab_size), dtype=np.int64)
138
  for b in range(batch_size):
 
157
  mask_ids = [i for i in mask_ids if i not in disable_channels]
158
  mask[b, mask_ids] = 1
159
  mask = mask[:, None, :]
160
+ x = next_token_seq
161
+ if i != 0:
162
+ # cached
163
+ if i == 1:
164
+ hidden = np.zeros((batch_size, 0, emb_size), dtype=np.float32)
165
+ model1_inputs["hidden"] = rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)
166
+ x = x[:, -1:]
167
+ model1_inputs["x"] = rt.OrtValue.ortvalue_from_numpy(x, device_type=device)
168
+ model1_outputs["y"] = rt.OrtValue.ortvalue_from_shape_and_type(
169
+ (batch_size, 1, tokenizer.vocab_size),
170
+ element_type=np.float32,
171
+ device_type=device
172
+ )
173
+ io_binding = apply_io_binding(model[1], model1_inputs, model1_outputs, batch_size, i, i+1)
174
+ io_binding.synchronize_inputs()
175
+ model[1].run_with_iobinding(io_binding)
176
+ io_binding.synchronize_outputs()
177
+ logits = model1_outputs["y"].numpy()
178
  scores = softmax(logits / temp, -1) * mask
179
  samples = sample_top_p_k(scores, top_p, top_k, generator)
180
  if i == 0:
 
197
  mode="constant", constant_values=tokenizer.pad_id)
198
  next_token_seq = next_token_seq[:, None, :]
199
  input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
200
+ past_len = cur_len
201
  cur_len += 1
202
  bar.update(1)
203
  yield next_token_seq[:, 0]
 
213
  return json.dumps(msgs)
214
 
215
 
 
 
 
216
  def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
217
  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
218
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
219
+ t = gen_events // 23
 
 
 
 
 
 
 
 
220
  if "large" in model_name:
221
+ t = gen_events // 14
222
+ return t + 5
223
 
224
 
225
  @spaces.GPU(duration=get_duration)
 
485
  }
486
  models = {}
487
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
488
+ device = "cuda"
489
 
490
  for name, (repo_id, path, config, loras) in models_info.items():
491
  model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
 
509
  "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
510
  " or [download windows app](https://github.com/SkyTNT/midi-model/releases)"
511
  " for unlimited generation\n\n"
512
+ "**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer\n\n"
513
+ "The current **best** model: generic pretrain model (tv2o-medium) by skytnt"
514
  )
515
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
516
  js_msg.change(None, [js_msg], [], js="""