Atin Sakkeer Hussain committed
Commit 4bae2e1
1 Parent(s): 5379b67

Add requirements

Files changed (1)
  1. app.py +3 -43
app.py CHANGED
@@ -5,7 +5,6 @@ import mdtex2html
 import tempfile
 from PIL import Image
 import scipy
-import argparse

 from llama.m2ugen import M2UGen
 import llama
@@ -18,44 +17,9 @@ import av
 import subprocess
 import librosa

-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--model", default="./ckpts/checkpoint.pth", type=str,
-    help="Name of or path to M2UGen pretrained checkpoint",
-)
-parser.add_argument(
-    "--llama_type", default="7B", type=str,
-    help="Type of llama original weight",
-)
-parser.add_argument(
-    "--llama_dir", default="/path/to/llama", type=str,
-    help="Path to LLaMA pretrained checkpoint",
-)
-parser.add_argument(
-    "--mert_path", default="m-a-p/MERT-v1-330M", type=str,
-    help="Path to MERT pretrained checkpoint",
-)
-parser.add_argument(
-    "--vit_path", default="m-a-p/MERT-v1-330M", type=str,
-    help="Path to ViT pretrained checkpoint",
-)
-parser.add_argument(
-    "--vivit_path", default="m-a-p/MERT-v1-330M", type=str,
-    help="Path to ViViT pretrained checkpoint",
-)
-parser.add_argument(
-    "--knn_dir", default="./ckpts", type=str,
-    help="Path to directory with KNN Index",
-)
-parser.add_argument(
-    '--music_decoder', default="musicgen", type=str,
-    help='Decoder to use musicgen/audioldm2')
-
-parser.add_argument(
-    '--music_decoder_path', default="facebook/musicgen-medium", type=str,
-    help='Path to decoder to use musicgen/audioldm2')
-
-args = parser.parse_args()
+args = {"model": "./ckpts/M2UGen/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
+        "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
+        "music_decoder": "musicgen", "music_decoder_path": "facebook/musicgen-medium"}

 generated_audio_files = []

@@ -78,10 +42,6 @@ load_result = model.load_state_dict(new_ckpt, strict=False)
 assert len(load_result.unexpected_keys) == 0, f"Unexpected keys: {load_result.unexpected_keys}"
 model.eval()
 model.to("cuda")
-#model.generation_model.to("cuda")
-#model.mert_model.to("cuda")
-#model.vit_model.to("cuda")
-#model.vivit_model.to("cuda")

 transform = transforms.Compose(
     [transforms.ToTensor(), transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x)])
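
Note on the new hard-coded configuration: args is now a plain dict, whereas the argparse code it replaces produced a Namespace accessed as args.model, args.llama_dir, and so on. If later parts of app.py (not shown in this diff) still use attribute access, a minimal sketch of one way to keep that pattern working, assuming such downstream usage:

    # Sketch only; assumes downstream code still reads args.model etc.,
    # which this diff does not show.
    from types import SimpleNamespace

    args = {"model": "./ckpts/M2UGen/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
            "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224",
            "vivit_path": "google/vivit-b-16x2-kinetics400",
            "music_decoder": "musicgen", "music_decoder_path": "facebook/musicgen-medium"}

    # Wrapping the dict restores attribute-style access.
    args = SimpleNamespace(**args)
    print(args.model)           # ./ckpts/M2UGen/checkpoint.pth
    print(args.music_decoder)   # musicgen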
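
For reference, the transforms.Compose pipeline kept at the end of the diff normalizes channel counts: ToTensor converts a PIL image to a C x H x W tensor, and the Lambda repeats a single grayscale channel so every image comes out as 3 x H x W. A small usage sketch; the 224x224 test images are illustrative, not taken from app.py:

    # Usage sketch for the channel-normalizing transform (test image sizes are arbitrary).
    from PIL import Image
    from torchvision import transforms

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x)])

    gray = Image.new("L", (224, 224))    # 1-channel image -> ToTensor yields shape [1, 224, 224]
    rgb = Image.new("RGB", (224, 224))   # 3-channel image -> passed through unchanged

    print(transform(gray).shape)  # torch.Size([3, 224, 224])
    print(transform(rgb).shape)   # torch.Size([3, 224, 224])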