asigalov61 commited on
Commit
a914076
1 Parent(s): a173e60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -6
app.py CHANGED
@@ -23,7 +23,7 @@ in_space = os.getenv("SYSTEM") == "spaces"
23
  # =================================================================================================
24
 
25
  @spaces.GPU
26
- def GenerateMIDI(num_tok, idrums, iinstr, input_top_k_ratio):
27
  print('=' * 70)
28
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
  start_time = time.time()
@@ -32,7 +32,6 @@ def GenerateMIDI(num_tok, idrums, iinstr, input_top_k_ratio):
32
  print('Req num tok:', num_tok)
33
  print('Req instr:', iinstr)
34
  print('Drums:', idrums)
35
- print('top_k:', input_top_k_ratio)
36
  print('-' * 70)
37
 
38
  if idrums:
@@ -127,8 +126,6 @@ def GenerateMIDI(num_tok, idrums, iinstr, input_top_k_ratio):
127
  with torch.inference_mode():
128
  out = model.module.generate(inp,
129
  1,
130
- filter_logits_fn = top_k,
131
- filter_thres = input_top_k_ratio,
132
  temperature=0.9,
133
  return_prime=False,
134
  verbose=False)
@@ -210,13 +207,12 @@ if __name__ == "__main__":
210
  value="Piano", label="Lead Instrument Controls", info="Desired lead instrument")
211
  input_drums = gr.Checkbox(label="Add Drums", value=False, info="Add drums to the composition")
212
  input_num_tokens = gr.Slider(16, 1024, value=512, label="Number of Tokens", info="Number of tokens to generate")
213
- input_top_k_ratio = gr.Slider(0.1, 1, value=0.95, step=0.01, label="Model sampling top_k ratio")
214
 
215
  run_btn = gr.Button("generate", variant="primary")
216
 
217
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
218
  output_plot = gr.Plot(label='output plot')
219
  output_midi = gr.File(label="output midi", file_types=[".mid"])
220
- run_event = run_btn.click(GenerateMIDI, [input_num_tokens, input_drums, input_instrument, input_top_k_ratio],
221
  [output_plot, output_midi, output_audio])
222
  app.queue().launch()
 
23
  # =================================================================================================
24
 
25
  @spaces.GPU
26
+ def GenerateMIDI(num_tok, idrums, iinstr):
27
  print('=' * 70)
28
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
  start_time = time.time()
 
32
  print('Req num tok:', num_tok)
33
  print('Req instr:', iinstr)
34
  print('Drums:', idrums)
 
35
  print('-' * 70)
36
 
37
  if idrums:
 
126
  with torch.inference_mode():
127
  out = model.module.generate(inp,
128
  1,
 
 
129
  temperature=0.9,
130
  return_prime=False,
131
  verbose=False)
 
207
  value="Piano", label="Lead Instrument Controls", info="Desired lead instrument")
208
  input_drums = gr.Checkbox(label="Add Drums", value=False, info="Add drums to the composition")
209
  input_num_tokens = gr.Slider(16, 1024, value=512, label="Number of Tokens", info="Number of tokens to generate")
 
210
 
211
  run_btn = gr.Button("generate", variant="primary")
212
 
213
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
214
  output_plot = gr.Plot(label='output plot')
215
  output_midi = gr.File(label="output midi", file_types=[".mid"])
216
+ run_event = run_btn.click(GenerateMIDI, [input_num_tokens, input_drums, input_instrument],
217
  [output_plot, output_midi, output_audio])
218
  app.queue().launch()