Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import argparse
|
2 |
import glob
|
3 |
import json
|
4 |
import os.path
|
@@ -11,6 +10,7 @@ import torch
|
|
11 |
import torch.nn.functional as F
|
12 |
|
13 |
import gradio as gr
|
|
|
14 |
|
15 |
from x_transformer import *
|
16 |
import tqdm
|
@@ -24,7 +24,7 @@ in_space = os.getenv("SYSTEM") == "spaces"
|
|
24 |
|
25 |
# =================================================================================================
|
26 |
|
27 |
-
@
|
28 |
def GenerateMIDI(num_tok, idrums, iinstr):
|
29 |
print('=' * 70)
|
30 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
@@ -83,6 +83,38 @@ def GenerateMIDI(num_tok, idrums, iinstr):
|
|
83 |
|
84 |
yield output, None, None, [create_msg("visualizer_clear", None)]
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
outy = start_tokens
|
87 |
|
88 |
ctime = 0
|
@@ -201,42 +233,6 @@ if __name__ == "__main__":
|
|
201 |
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
202 |
print('=' * 70)
|
203 |
|
204 |
-
parser = argparse.ArgumentParser()
|
205 |
-
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
206 |
-
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
207 |
-
opt = parser.parse_args()
|
208 |
-
|
209 |
-
print('Loading model...')
|
210 |
-
|
211 |
-
SEQ_LEN = 2048
|
212 |
-
|
213 |
-
# instantiate the model
|
214 |
-
|
215 |
-
model = TransformerWrapper(
|
216 |
-
num_tokens=3088,
|
217 |
-
max_seq_len=SEQ_LEN,
|
218 |
-
attn_layers=Decoder(dim=1024, depth=16, heads=8)
|
219 |
-
)
|
220 |
-
|
221 |
-
model = AutoregressiveWrapper(model)
|
222 |
-
|
223 |
-
model = torch.nn.DataParallel(model)
|
224 |
-
|
225 |
-
model.cpu()
|
226 |
-
print('=' * 70)
|
227 |
-
|
228 |
-
print('Loading model checkpoint...')
|
229 |
-
|
230 |
-
model.load_state_dict(
|
231 |
-
torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
|
232 |
-
map_location='cpu'))
|
233 |
-
print('=' * 70)
|
234 |
-
|
235 |
-
model.eval()
|
236 |
-
|
237 |
-
print('Done!')
|
238 |
-
print('=' * 70)
|
239 |
-
|
240 |
load_javascript()
|
241 |
app = gr.Blocks()
|
242 |
with app:
|
@@ -267,4 +263,4 @@ if __name__ == "__main__":
|
|
267 |
[output_midi_seq, output_midi, output_audio, js_msg])
|
268 |
interrupt_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg],
|
269 |
cancels=run_event, queue=False)
|
270 |
-
app.queue(
|
|
|
|
|
1 |
import glob
|
2 |
import json
|
3 |
import os.path
|
|
|
10 |
import torch.nn.functional as F
|
11 |
|
12 |
import gradio as gr
|
13 |
+
import spaces
|
14 |
|
15 |
from x_transformer import *
|
16 |
import tqdm
|
|
|
24 |
|
25 |
# =================================================================================================
|
26 |
|
27 |
+
@spaces.GPU
|
28 |
def GenerateMIDI(num_tok, idrums, iinstr):
|
29 |
print('=' * 70)
|
30 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
|
|
83 |
|
84 |
yield output, None, None, [create_msg("visualizer_clear", None)]
|
85 |
|
86 |
+
|
87 |
+
print('Loading model...')
|
88 |
+
|
89 |
+
SEQ_LEN = 2048
|
90 |
+
|
91 |
+
# instantiate the model
|
92 |
+
|
93 |
+
model = TransformerWrapper(
|
94 |
+
num_tokens=3088,
|
95 |
+
max_seq_len=SEQ_LEN,
|
96 |
+
attn_layers=Decoder(dim=1024, depth=16, heads=8)
|
97 |
+
)
|
98 |
+
|
99 |
+
model = AutoregressiveWrapper(model)
|
100 |
+
|
101 |
+
model = torch.nn.DataParallel(model)
|
102 |
+
|
103 |
+
model.cpu()
|
104 |
+
print('=' * 70)
|
105 |
+
|
106 |
+
print('Loading model checkpoint...')
|
107 |
+
|
108 |
+
model.load_state_dict(
|
109 |
+
torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
|
110 |
+
map_location='cpu'))
|
111 |
+
print('=' * 70)
|
112 |
+
|
113 |
+
model.eval()
|
114 |
+
|
115 |
+
print('Done!')
|
116 |
+
print('=' * 70)
|
117 |
+
|
118 |
outy = start_tokens
|
119 |
|
120 |
ctime = 0
|
|
|
233 |
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
234 |
print('=' * 70)
|
235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
load_javascript()
|
237 |
app = gr.Blocks()
|
238 |
with app:
|
|
|
263 |
[output_midi_seq, output_midi, output_audio, js_msg])
|
264 |
interrupt_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg],
|
265 |
cancels=run_event, queue=False)
|
266 |
+
app.queue().launch()
|