asigalov61 commited on
Commit
41327e9
1 Parent(s): 4e4340d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -126
app.py CHANGED
@@ -13,76 +13,6 @@ from midi_synthesizer import synthesis
13
 
14
  in_space = os.getenv("SYSTEM") == "spaces"
15
 
16
-
17
- def find_midi():
18
- if disable_channels is not None:
19
- disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
20
- else:
21
- disable_channels = []
22
- max_token_seq = tokenizer.max_token_seq
23
- if prompt is None:
24
- input_tensor = np.full((1, max_token_seq), tokenizer.pad_id, dtype=np.int64)
25
- input_tensor[0, 0] = tokenizer.bos_id # bos
26
- else:
27
- prompt = prompt[:, :max_token_seq]
28
- if prompt.shape[-1] < max_token_seq:
29
- prompt = np.pad(prompt, ((0, 0), (0, max_token_seq - prompt.shape[-1])),
30
- mode="constant", constant_values=tokenizer.pad_id)
31
- input_tensor = prompt
32
- input_tensor = input_tensor[None, :, :]
33
- cur_len = input_tensor.shape[1]
34
- bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
35
- with bar:
36
- while cur_len < max_len:
37
- end = False
38
- hidden = model[0].run(None, {'x': input_tensor})[0][:, -1]
39
- next_token_seq = np.empty((1, 0), dtype=np.int64)
40
- event_name = ""
41
- for i in range(max_token_seq):
42
- mask = np.zeros(tokenizer.vocab_size, dtype=np.int64)
43
- if i == 0:
44
- mask_ids = list(tokenizer.event_ids.values()) + [tokenizer.eos_id]
45
- if disable_patch_change:
46
- mask_ids.remove(tokenizer.event_ids["patch_change"])
47
- if disable_control_change:
48
- mask_ids.remove(tokenizer.event_ids["control_change"])
49
- mask[mask_ids] = 1
50
- else:
51
- param_name = tokenizer.events[event_name][i - 1]
52
- mask_ids = tokenizer.parameter_ids[param_name]
53
- if param_name == "channel":
54
- mask_ids = [i for i in mask_ids if i not in disable_channels]
55
- mask[mask_ids] = 1
56
- logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
57
- scores = softmax(logits / temp, -1) * mask
58
- sample = sample_top_p_k(scores, top_p, top_k)
59
- if i == 0:
60
- next_token_seq = sample
61
- eid = sample.item()
62
- if eid == tokenizer.eos_id:
63
- end = True
64
- break
65
- event_name = tokenizer.id_events[eid]
66
- else:
67
- next_token_seq = np.concatenate([next_token_seq, sample], axis=1)
68
- if len(tokenizer.events[event_name]) == i:
69
- break
70
- if next_token_seq.shape[1] < max_token_seq:
71
- next_token_seq = np.pad(next_token_seq, ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
72
- mode="constant", constant_values=tokenizer.pad_id)
73
- next_token_seq = next_token_seq[None, :, :]
74
- input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
75
- cur_len += 1
76
- bar.update(1)
77
- yield next_token_seq.reshape(-1)
78
- if end:
79
- break
80
-
81
-
82
- def create_msg(name, data):
83
- return {"name": name, "data": data}
84
-
85
-
86
  def run(search_prompt, mid=None):
87
  mid_seq = []
88
 
@@ -95,24 +25,11 @@ def run(search_prompt, mid=None):
95
 
96
  elif mid is not None:
97
  mid_seq = MIDI.midi2score(mid)
98
-
99
- init_msgs = [create_msg("visualizer_clear", None)]
100
- for event in mid_seq:
101
- if event[0] == 'note':
102
- init_msgs.append(create_msg("visualizer_append", event))
103
- yield mid_seq, None, None, init_msgs
104
-
105
- # j = 0
106
- # for i in range(len(mid_seq)-1):
107
- # if mid_seq[i][0] == 'note':
108
- # j += 1
109
- # yield mid_seq, None, None, [create_msg("visualizer_append", mid_seq[i]), create_msg("progress", [j + 1, len(mid_seq)])]
110
-
111
-
112
  with open(f"output.mid", 'wb') as f:
113
  f.write(MIDI.score2midi([mid_seq_ticks, mid_seq]))
114
  audio = synthesis(MIDI.score2opus([mid_seq_ticks, mid_seq]), soundfont_path)
115
- yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
116
 
117
 
118
  def cancel_run(mid_seq):
@@ -120,40 +37,9 @@ def cancel_run(mid_seq):
120
  return None, None
121
 
122
  with open(f"output.mid", 'wb') as f:
123
- f.write(MIDI.score2midi(mid_seq))
124
- audio = synthesis(MIDI.score2opus(mid_seq), soundfont_path)
125
- return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
126
-
127
-
128
- def load_javascript(dir="javascript"):
129
- scripts_list = glob.glob(f"{dir}/*.js")
130
- javascript = ""
131
- for path in scripts_list:
132
- with open(path, "r", encoding="utf8") as jsfile:
133
- javascript += f"\n<!-- {path} --><script>{jsfile.read()}</script>"
134
- template_response_ori = gr.routes.templates.TemplateResponse
135
-
136
- def template_response(*args, **kwargs):
137
- res = template_response_ori(*args, **kwargs)
138
- res.body = res.body.replace(
139
- b'</head>', f'{javascript}</head>'.encode("utf8"))
140
- res.init_headers()
141
- return res
142
-
143
- gr.routes.templates.TemplateResponse = template_response
144
-
145
-
146
- class JSMsgReceiver(gr.HTML):
147
- def __init__(self, **kwargs):
148
- super().__init__(elem_id="msg_receiver", visible=False, **kwargs)
149
-
150
- def postprocess(self, y):
151
- if y:
152
- y = f"<p>{json.dumps(y)}</p>"
153
- return super().postprocess(y)
154
-
155
- def get_block_name(self) -> str:
156
- return "html"
157
 
158
  if __name__ == "__main__":
159
  parser = argparse.ArgumentParser()
@@ -176,8 +62,6 @@ if __name__ == "__main__":
176
  meta_data = pickle.load(f)
177
  print('Done!')
178
 
179
-
180
- load_javascript()
181
  app = gr.Blocks()
182
  with app:
183
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Search</h1>")
@@ -189,8 +73,6 @@ if __name__ == "__main__":
189
  " for faster running and longer generation"
190
  )
191
 
192
- js_msg = JSMsgReceiver()
193
-
194
  with gr.Tabs():
195
  with gr.TabItem("instrument prompt") as tab1:
196
 
@@ -206,10 +88,9 @@ if __name__ == "__main__":
206
  search_btn = gr.Button("search", variant="primary")
207
  stop_btn = gr.Button("stop and output")
208
  output_midi_seq = gr.Textbox()
209
- output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
210
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
211
  output_midi = gr.File(label="output midi", file_types=[".mid"])
212
  run_event = search_btn.click(run, [search_prompt],
213
- [output_midi_seq, output_midi, output_audio, js_msg])
214
- stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
215
  app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)
 
13
 
14
  in_space = os.getenv("SYSTEM") == "spaces"
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def run(search_prompt, mid=None):
17
  mid_seq = []
18
 
 
25
 
26
  elif mid is not None:
27
  mid_seq = MIDI.midi2score(mid)
28
+
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  with open(f"output.mid", 'wb') as f:
30
  f.write(MIDI.score2midi([mid_seq_ticks, mid_seq]))
31
  audio = synthesis(MIDI.score2opus([mid_seq_ticks, mid_seq]), soundfont_path)
32
+ yield mid_seq, "output.mid", (44100, audio)
33
 
34
 
35
  def cancel_run(mid_seq):
 
37
  return None, None
38
 
39
  with open(f"output.mid", 'wb') as f:
40
+ f.write(MIDI.score2midi([1000, mid_seq]))
41
+ audio = synthesis(MIDI.score2opus([1000, mid_seq]), soundfont_path)
42
+ return "output.mid", (44100, audio)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  if __name__ == "__main__":
45
  parser = argparse.ArgumentParser()
 
62
  meta_data = pickle.load(f)
63
  print('Done!')
64
 
 
 
65
  app = gr.Blocks()
66
  with app:
67
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Search</h1>")
 
73
  " for faster running and longer generation"
74
  )
75
 
 
 
76
  with gr.Tabs():
77
  with gr.TabItem("instrument prompt") as tab1:
78
 
 
88
  search_btn = gr.Button("search", variant="primary")
89
  stop_btn = gr.Button("stop and output")
90
  output_midi_seq = gr.Textbox()
 
91
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
92
  output_midi = gr.File(label="output midi", file_types=[".mid"])
93
  run_event = search_btn.click(run, [search_prompt],
94
+ [output_midi_seq, output_midi, output_audio])
95
+ stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
96
  app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)