Spaces:
Running
on
Zero
Running
on
Zero
asigalov61
commited on
Commit
•
25f28a8
1
Parent(s):
9ce2c1e
Update app.py
Browse files
app.py
CHANGED
@@ -23,7 +23,7 @@ in_space = os.getenv("SYSTEM") == "spaces"
|
|
23 |
# =================================================================================================
|
24 |
|
25 |
@spaces.GPU
|
26 |
-
def GenerateMIDI(num_tok, idrums, iinstr):
|
27 |
print('=' * 70)
|
28 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
29 |
start_time = time.time()
|
@@ -126,6 +126,8 @@ def GenerateMIDI(num_tok, idrums, iinstr):
|
|
126 |
with torch.inference_mode():
|
127 |
out = model.module.generate(inp,
|
128 |
1,
|
|
|
|
|
129 |
temperature=0.9,
|
130 |
return_prime=False,
|
131 |
verbose=False)
|
@@ -207,12 +209,13 @@ if __name__ == "__main__":
|
|
207 |
value="Piano", label="Lead Instrument Controls", info="Desired lead instrument")
|
208 |
input_drums = gr.Checkbox(label="Add Drums", value=False, info="Add drums to the composition")
|
209 |
input_num_tokens = gr.Slider(16, 1024, value=512, label="Number of Tokens", info="Number of tokens to generate")
|
|
|
210 |
|
211 |
run_btn = gr.Button("generate", variant="primary")
|
212 |
-
|
213 |
-
output_plot = gr.Plot(label='output plot')
|
214 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
|
|
215 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
216 |
-
run_event = run_btn.click(GenerateMIDI, [input_num_tokens, input_drums, input_instrument],
|
217 |
[output_plot, output_midi, output_audio])
|
218 |
app.queue().launch()
|
|
|
23 |
# =================================================================================================
|
24 |
|
25 |
@spaces.GPU
|
26 |
+
def GenerateMIDI(num_tok, idrums, iinstr, input_top_k_value):
|
27 |
print('=' * 70)
|
28 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
29 |
start_time = time.time()
|
|
|
126 |
with torch.inference_mode():
|
127 |
out = model.module.generate(inp,
|
128 |
1,
|
129 |
+
filter_logits_fn=top_k,
|
130 |
+
filter_kwargs={'k': input_top_k_value},
|
131 |
temperature=0.9,
|
132 |
return_prime=False,
|
133 |
verbose=False)
|
|
|
209 |
value="Piano", label="Lead Instrument Controls", info="Desired lead instrument")
|
210 |
input_drums = gr.Checkbox(label="Add Drums", value=False, info="Add drums to the composition")
|
211 |
input_num_tokens = gr.Slider(16, 1024, value=512, label="Number of Tokens", info="Number of tokens to generate")
|
212 |
+
input_top_k_value = gr.Slider(1, 100, value=15, label="Model sampling top_k value")
|
213 |
|
214 |
run_btn = gr.Button("generate", variant="primary")
|
215 |
+
|
|
|
216 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
217 |
+
output_plot = gr.Plot(label='output plot')
|
218 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
219 |
+
run_event = run_btn.click(GenerateMIDI, [input_num_tokens, input_drums, input_instrument, input_top_k_value],
|
220 |
[output_plot, output_midi, output_audio])
|
221 |
app.queue().launch()
|