asigalov61 commited on
Commit
ff7362d
·
verified ·
1 Parent(s): 86b7652

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -88,7 +88,7 @@ def GenerateMIDI(num_tok, idrums, iinstr):
88
  model = TransformerWrapper(
89
  num_tokens=3088,
90
  max_seq_len=SEQ_LEN,
91
- attn_layers=Decoder(dim=1024, depth=16, heads=8, attn_flash=True)
92
  )
93
 
94
  model = AutoregressiveWrapper(model)
@@ -101,7 +101,7 @@ def GenerateMIDI(num_tok, idrums, iinstr):
101
  print('Loading model checkpoint...')
102
 
103
  model.load_state_dict(
104
- torch.load('Allegro_Music_Transformer_Tiny_Trained_Model_80000_steps_0.9457_loss_0.7443_acc.pth',
105
  map_location='cuda'))
106
  print('=' * 70)
107
 
 
88
  model = TransformerWrapper(
89
  num_tokens=3088,
90
  max_seq_len=SEQ_LEN,
91
+ attn_layers=Decoder(dim=1024, depth=32, heads=8, attn_flash=True)
92
  )
93
 
94
  model = AutoregressiveWrapper(model)
 
101
  print('Loading model checkpoint...')
102
 
103
  model.load_state_dict(
104
+ torch.load('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.pth',
105
  map_location='cuda'))
106
  print('=' * 70)
107