skytnt committed on
Commit
43a6dd3
1 Parent(s): fc457c0
Files changed (2)
  1. app.py +4 -1
  2. midi_model.py +1 -4
app.py CHANGED
@@ -223,13 +223,16 @@ if __name__ == "__main__":
         "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
     }
     device = "cuda" if torch.cuda.is_available() else "cpu"
+    if device=="cuda":  # flash attn
+        torch.backends.cuda.enable_mem_efficient_sdp(True)
+        torch.backends.cuda.enable_flash_sdp(True)
     models = {}
     tokenizer = MIDITokenizer()
     for name, (repo_id, path) in models_info.items():
 
         model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
         model = MIDIModel(tokenizer).to(device=device)
-        ckpt = torch.load(model_path)
+        ckpt = torch.load(model_path, weights_only=True)
         state_dict = ckpt.get("state_dict", ckpt)
         model.load_state_dict(state_dict, strict=False)
         model.eval()
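For context, the two app.py changes follow a common PyTorch loading pattern: opt into the flash and memory-efficient SDPA kernels only when CUDA is available, and restrict torch.load to plain tensor data with weights_only=True so arbitrary pickled objects in a downloaded checkpoint are rejected. A minimal standalone sketch of that pattern (TinyModel and demo.ckpt are placeholders for illustration, not part of this repository):

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    # Placeholder stand-in for MIDIModel, just so the sketch is self-contained.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)


device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # Enable the fused scaled-dot-product-attention backends; this is a
    # process-global toggle and a no-op on CPU-only setups.
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)

model = TinyModel().to(device=device)
torch.save({"state_dict": model.state_dict()}, "demo.ckpt")  # placeholder checkpoint

# weights_only=True refuses to unpickle arbitrary Python objects,
# which is safer for checkpoints fetched from a model hub.
ckpt = torch.load("demo.ckpt", weights_only=True)
state_dict = ckpt.get("state_dict", ckpt)
model.load_state_dict(state_dict, strict=False)
model.eval()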
midi_model.py CHANGED
@@ -9,7 +9,7 @@ from midi_tokenizer import MIDITokenizer
 
 
 class MIDIModel(nn.Module):
-    def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096, flash=False,
+    def __init__(self, tokenizer: MIDITokenizer, n_layer=12, n_head=16, n_embd=1024, n_inner=4096,
                  *args, **kwargs):
         super(MIDIModel, self).__init__()
         self.tokenizer = tokenizer
@@ -21,9 +21,6 @@ class MIDIModel(nn.Module):
                                          hidden_size=n_embd, num_attention_heads=n_head // 4,
                                          num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
                                          pad_token_id=tokenizer.pad_id, max_position_embeddings=4096))
-        if flash:
-            torch.backends.cuda.enable_mem_efficient_sdp(True)
-            torch.backends.cuda.enable_flash_sdp(True)
         self.lm_head = nn.Linear(n_embd, tokenizer.vocab_size, bias=False)
         self.device = "cpu"
 
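With this change, MIDIModel no longer accepts a flash flag; the SDPA backend toggles are process-global, so the caller (app.py above) sets them once before constructing the model. A hedged usage sketch, assuming the repository's midi_model and midi_tokenizer modules are on the import path:

import torch
from midi_model import MIDIModel
from midi_tokenizer import MIDITokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # Backend selection now lives with the caller instead of MIDIModel.__init__.
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)

tokenizer = MIDITokenizer()
model = MIDIModel(tokenizer).to(device=device)  # the flash flag was removed in this commit
model.eval()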