asigalov61 commited on
Commit
399d36d
·
verified ·
1 Parent(s): cee37ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -93,21 +93,21 @@ def GenerateMIDI(num_tok, idrums, iinstr):
93
  model = TransformerWrapper(
94
  num_tokens=3088,
95
  max_seq_len=SEQ_LEN,
96
- attn_layers=Decoder(dim=1024, depth=16, heads=8)
97
  )
98
 
99
  model = AutoregressiveWrapper(model)
100
 
101
  model = torch.nn.DataParallel(model)
102
 
103
- model.cpu()
104
  print('=' * 70)
105
 
106
  print('Loading model checkpoint...')
107
 
108
  model.load_state_dict(
109
  torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
110
- map_location='cpu'))
111
  print('=' * 70)
112
 
113
  model.eval()
@@ -125,13 +125,14 @@ def GenerateMIDI(num_tok, idrums, iinstr):
125
 
126
  for i in range(max(1, min(512, num_tok))):
127
 
128
- inp = torch.LongTensor([outy]).cpu()
129
 
130
- out = model.module.generate(inp,
131
- 1,
132
- temperature=0.9,
133
- return_prime=False,
134
- verbose=False)
 
135
 
136
  out0 = out[0].tolist()
137
  outy.extend(out0)
@@ -253,7 +254,7 @@ if __name__ == "__main__":
253
  run_btn = gr.Button("generate", variant="primary")
254
  interrupt_btn = gr.Button("interrupt")
255
 
256
- output_midi_seq = gr.JSON()
257
  output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
258
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
259
  output_midi = gr.File(label="output midi", file_types=[".mid"])
 
93
  model = TransformerWrapper(
94
  num_tokens=3088,
95
  max_seq_len=SEQ_LEN,
96
+ attn_layers=Decoder(dim=1024, depth=16, heads=8, attn_flash=True)
97
  )
98
 
99
  model = AutoregressiveWrapper(model)
100
 
101
  model = torch.nn.DataParallel(model)
102
 
103
+ model.cuda()
104
  print('=' * 70)
105
 
106
  print('Loading model checkpoint...')
107
 
108
  model.load_state_dict(
109
  torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
110
+ map_location='cuda'))
111
  print('=' * 70)
112
 
113
  model.eval()
 
125
 
126
  for i in range(max(1, min(512, num_tok))):
127
 
128
+ inp = torch.LongTensor([outy]).cuda()
129
 
130
+ with torch.amp.autocast(device_type='cuda', dtype=torch.float16)
131
+ out = model.module.generate(inp,
132
+ 1,
133
+ temperature=0.9,
134
+ return_prime=False,
135
+ verbose=False)
136
 
137
  out0 = out[0].tolist()
138
  outy.extend(out0)
 
254
  run_btn = gr.Button("generate", variant="primary")
255
  interrupt_btn = gr.Button("interrupt")
256
 
257
+ output_midi_seq = gr.HTML()
258
  output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
259
  output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
260
  output_midi = gr.File(label="output midi", file_types=[".mid"])