keshavbhandari commited on
Commit
60e9e83
1 Parent(s): fd955fa

Add model download functionality

Browse files
Files changed (2) hide show
  1. app.py +6 -1
  2. requirements.txt +2 -1
app.py CHANGED
@@ -12,6 +12,11 @@ from transformers import T5Tokenizer
12
  from transformer_model import Transformer
13
  from miditok import REMI, TokenizerConfig
14
  from pathlib import Path
 
 
 
 
 
15
 
16
 
17
  def save_wav(filepath):
@@ -81,7 +86,7 @@ def generate_midi(caption, temperature=0.9, max_len=500):
81
  vocab_size = len(r_tokenizer)
82
  print("Vocab size: ", vocab_size)
83
  model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
84
- model_path = os.path.join("amaai-lab/text2midi", "pytorch_model.bin")
85
  model.load_state_dict(torch.load(model_path, map_location=device))
86
  model.eval()
87
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
 
12
  from transformer_model import Transformer
13
  from miditok import REMI, TokenizerConfig
14
  from pathlib import Path
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ repo_id = "amaai-lab/text2midi"
18
+ # Download the model.bin file
19
+ model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
20
 
21
 
22
  def save_wav(filepath):
 
86
  vocab_size = len(r_tokenizer)
87
  print("Vocab size: ", vocab_size)
88
  model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
89
+ # model_path = os.path.join("amaai-lab/text2midi", "pytorch_model.bin")
90
  model.load_state_dict(torch.load(model_path, map_location=device))
91
  model.eval()
92
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
requirements.txt CHANGED
@@ -5,4 +5,5 @@ transformers
5
  st-moe-pytorch
6
  jsonlines
7
  miditok==3.0.3
8
- sentencepiece
 
 
5
  st-moe-pytorch
6
  jsonlines
7
  miditok==3.0.3
8
+ sentencepiece
9
+ huggingface_hub