sweetcocoa commited on
Commit
ceb92ad
1 Parent(s): 604b1f5

cuda support

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import os
3
  from transformer_wrapper import TransformerWrapper
@@ -18,6 +19,9 @@ def model_load():
18
  )
19
  model_id = "dpipqxiy"
20
  wrapper.eval()
 
 
 
21
  return wrapper, model_id, config
22
 
23
 
@@ -37,7 +41,7 @@ def inference(file_up, composer):
37
  show_plot=False,
38
  save_midi=True,
39
  save_mix=True,
40
- midi_path="output.mid"
41
  )
42
 
43
  return mix_path, midi_path
 
1
+ import torch
2
  import gradio as gr
3
  import os
4
  from transformer_wrapper import TransformerWrapper
 
19
  )
20
  model_id = "dpipqxiy"
21
  wrapper.eval()
22
+ if torch.cuda.is_available():
23
+ wrapper = wrapper.cuda()
24
+
25
  return wrapper, model_id, config
26
 
27
 
 
41
  show_plot=False,
42
  save_midi=True,
43
  save_mix=True,
44
+ midi_path="output.mid",
45
  )
46
 
47
  return mix_path, midi_path