keshavbhandari commited on
Commit
d63e721
1 Parent(s): 918dd43

added spaces.gpu

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -31,7 +31,7 @@ def save_wav(filepath):
31
  return wav_filepath
32
 
33
 
34
- def generate_midi(caption, temperature=0.9, max_len=3000):
35
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
36
  artifact_folder = 'artifacts'
37
 
@@ -60,10 +60,10 @@ def generate_midi(caption, temperature=0.9, max_len=3000):
60
  generated_midi = r_tokenizer.decode(output_list)
61
  generated_midi.dump_midi("output.mid")
62
 
63
- # @spaces.GPU(duration=120)
64
- def gradio_generate(prompt, temperature):
65
  # Generate midi
66
- generate_midi(prompt, temperature)
67
 
68
  # Convert midi to wav
69
  filename = "output.mid"
@@ -88,6 +88,7 @@ Generate midi music using Text2midi by providing a text prompt.
88
  input_text = gr.Textbox(lines=2, label="Prompt")
89
  output_audio = gr.Audio(label="Generated Music", type="filepath")
90
  temperature = gr.Slider(minimum=0.5, maximum=1.2, value=1.0, step=0.1, label="Temperature", interactive=True)
 
91
 
92
  # CSS styling for the Duplicate button
93
  css = '''
@@ -102,7 +103,7 @@ css = '''
102
  # Gradio interface
103
  gr_interface = gr.Interface(
104
  fn=gradio_generate,
105
- inputs=[input_text, temperature],
106
  outputs=[output_audio],
107
  description=description_text,
108
  allow_flagging=False,
 
31
  return wav_filepath
32
 
33
 
34
+ def generate_midi(caption, temperature=0.9, max_len=1000):
35
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
36
  artifact_folder = 'artifacts'
37
 
 
60
  generated_midi = r_tokenizer.decode(output_list)
61
  generated_midi.dump_midi("output.mid")
62
 
63
+ @spaces.GPU
64
+ def gradio_generate(prompt, temperature, max_length):
65
  # Generate midi
66
+ generate_midi(prompt, temperature, max_length)
67
 
68
  # Convert midi to wav
69
  filename = "output.mid"
 
88
  input_text = gr.Textbox(lines=2, label="Prompt")
89
  output_audio = gr.Audio(label="Generated Music", type="filepath")
90
  temperature = gr.Slider(minimum=0.5, maximum=1.2, value=1.0, step=0.1, label="Temperature", interactive=True)
91
+ max_length = gr.Number(default=1000, label="Max Length", min=100, max=2000, step=100)
92
 
93
  # CSS styling for the Duplicate button
94
  css = '''
 
103
  # Gradio interface
104
  gr_interface = gr.Interface(
105
  fn=gradio_generate,
106
+ inputs=[input_text, temperature, max_length],
107
  outputs=[output_audio],
108
  description=description_text,
109
  allow_flagging=False,