Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -88,7 +88,7 @@ def GenerateMIDI(num_tok, idrums, iinstr):
|
|
88 |
model = TransformerWrapper(
|
89 |
num_tokens=3088,
|
90 |
max_seq_len=SEQ_LEN,
|
91 |
-
attn_layers=Decoder(dim=1024, depth=
|
92 |
)
|
93 |
|
94 |
model = AutoregressiveWrapper(model)
|
@@ -101,7 +101,7 @@ def GenerateMIDI(num_tok, idrums, iinstr):
|
|
101 |
print('Loading model checkpoint...')
|
102 |
|
103 |
model.load_state_dict(
|
104 |
-
torch.load('
|
105 |
map_location='cuda'))
|
106 |
print('=' * 70)
|
107 |
|
|
|
88 |
model = TransformerWrapper(
|
89 |
num_tokens=3088,
|
90 |
max_seq_len=SEQ_LEN,
|
91 |
+
attn_layers=Decoder(dim=1024, depth=32, heads=8, attn_flash=True)
|
92 |
)
|
93 |
|
94 |
model = AutoregressiveWrapper(model)
|
|
|
101 |
print('Loading model checkpoint...')
|
102 |
|
103 |
model.load_state_dict(
|
104 |
+
torch.load('Allegro_Music_Transformer_Small_Trained_Model_56000_steps_0.9399_loss_0.7374_acc.pth',
|
105 |
map_location='cuda'))
|
106 |
print('=' * 70)
|
107 |
|