crypto-code committed
Commit e94a16b
1 parent: f32caf3

Update llama/m2ugen.py

Files changed (1)
  1. llama/m2ugen.py +58 -58
llama/m2ugen.py CHANGED
@@ -43,10 +43,10 @@ class M2UGen(nn.Module):
  # https://huggingface.co/m-a-p/MERT-v1-330M
  # And set the mert_path argument to directory with the model files
  print(f'Initialize MERT...')
- self.mert_model = AutoModel.from_pretrained(self.args.mert_path, trust_remote_code=True) # .to(self.device)
+ self.mert_model = AutoModel.from_pretrained(self.args.mert_path, trust_remote_code=True).to("cuda:0")
  self.mert_processor = Wav2Vec2FeatureExtractor.from_pretrained(self.args.mert_path, trust_remote_code=True)
- self.mu_mert_agg = nn.Conv1d(in_channels=25, out_channels=1, kernel_size=1)
- self.mu_mert_proj = nn.Linear(1024, 4096)
+ self.mu_mert_agg = nn.Conv1d(in_channels=25, out_channels=1, kernel_size=1).to("cuda:0")
+ self.mu_mert_proj = nn.Linear(1024, 4096).to("cuda:0")

  if legacy_bridge:
  bridge_norm_layer = nn.LayerNorm
@@ -57,20 +57,20 @@ class M2UGen(nn.Module):

  self.feature_scaler = 1

- self.mu_mert_norm_1 = bridge_norm_layer(4096)
- self.mu_mert_f1_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
- self.mu_mert_f2_1 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
- self.mu_mert_f3_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
+ self.mu_mert_norm_1 = bridge_norm_layer(4096).to("cuda:0")
+ self.mu_mert_f1_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+ self.mu_mert_f2_1 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias).to("cuda:0")
+ self.mu_mert_f3_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")

- self.mu_mert_norm_2 = bridge_norm_layer(4096)
- self.mu_mert_f1_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
- self.mu_mert_f2_2 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
- self.mu_mert_f3_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
+ self.mu_mert_norm_2 = bridge_norm_layer(4096).to("cuda:0")
+ self.mu_mert_f1_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+ self.mu_mert_f2_2 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias).to("cuda:0")
+ self.mu_mert_f3_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")

- self.mu_mert_norm_3 = bridge_norm_layer(4096)
- self.mu_mert_f1_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
- self.mu_mert_f2_3 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
- self.mu_mert_f3_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
+ self.mu_mert_norm_3 = bridge_norm_layer(4096).to("cuda:0")
+ self.mu_mert_f1_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+ self.mu_mert_f2_3 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias).to("cuda:0")
+ self.mu_mert_f3_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
  print(f'MERT initialized...')

  # 2. ViT Encoder
@@ -78,26 +78,26 @@ class M2UGen(nn.Module):
  # https://huggingface.co/google/vit-base-patch16-224-in21k
  # And set the vit_path argument to directory with the model files
  print(f'Initialize ViT...')
- self.vit_model = ViTModel.from_pretrained(self.args.vit_path) # .to(self.device)
+ self.vit_model = ViTModel.from_pretrained(self.args.vit_path).to("cuda:0") # .to(self.device)
  self.vit_model.eval()
  self.vit_processor = ViTImageProcessor.from_pretrained(self.args.vit_path, do_rescale=False)
- self.iu_vit_agg = nn.Conv1d(in_channels=197, out_channels=1, kernel_size=1)
- self.iu_vit_proj = nn.Linear(768, 4096)
-
- self.iu_vit_norm_1 = bridge_norm_layer(4096)
- self.iu_vit_f1_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
- self.iu_vit_f2_1 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
- self.iu_vit_f3_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
-
- self.iu_vit_norm_2 = bridge_norm_layer(4096)
- self.iu_vit_f1_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
- self.iu_vit_f2_2 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
- self.iu_vit_f3_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
-
- self.iu_vit_norm_3 = bridge_norm_layer(4096)
- self.iu_vit_f1_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
- self.iu_vit_f2_3 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
- self.iu_vit_f3_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
+ self.iu_vit_agg = nn.Conv1d(in_channels=197, out_channels=1, kernel_size=1).to("cuda:0")
+ self.iu_vit_proj = nn.Linear(768, 4096).to("cuda:0")
+
+ self.iu_vit_norm_1 = bridge_norm_layer(4096).to("cuda:0")
+ self.iu_vit_f1_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+ self.iu_vit_f2_1 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias).to("cuda:0")
+ self.iu_vit_f3_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+
+ self.iu_vit_norm_2 = bridge_norm_layer(4096).to("cuda:0")
+ self.iu_vit_f1_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+ self.iu_vit_f2_2 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias).to("cuda:0")
+ self.iu_vit_f3_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+
+ self.iu_vit_norm_3 = bridge_norm_layer(4096).to("cuda:0")
+ self.iu_vit_f1_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+ self.iu_vit_f2_3 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias).to("cuda:0")
+ self.iu_vit_f3_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
  print(f'ViT initialized...')

  # 3. ViViT Encoder
@@ -105,26 +105,26 @@ class M2UGen(nn.Module):
  # https://huggingface.co/google/vivit-b-16x2-kinetics400
  # And set the vivit_path argument to directory with the model files
  print(f'Initialize ViViT...')
- self.vivit_model = VivitModel.from_pretrained(self.args.vivit_path) # .to(self.device)
+ self.vivit_model = VivitModel.from_pretrained(self.args.vivit_path).to("cuda:0") # .to(self.device)
  self.vivit_model.eval()
  self.vivit_processor = VivitImageProcessor.from_pretrained(self.args.vivit_path)
- self.iu_vivit_agg = nn.Conv1d(in_channels=3137, out_channels=1, kernel_size=1)
- self.iu_vivit_proj = nn.Linear(768, 4096)
-
- self.iu_vivit_norm_1 = bridge_norm_layer(4096)
- self.iu_vivit_f1_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
- self.iu_vivit_f2_1 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
- self.iu_vivit_f3_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
-
- self.iu_vivit_norm_2 = bridge_norm_layer(4096)
- self.iu_vivit_f1_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
- self.iu_vivit_f2_2 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
- self.iu_vivit_f3_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
-
- self.iu_vivit_norm_3 = bridge_norm_layer(4096)
- self.iu_vivit_f1_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
- self.iu_vivit_f2_3 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias)
- self.iu_vivit_f3_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias)
+ self.iu_vivit_agg = nn.Conv1d(in_channels=3137, out_channels=1, kernel_size=1).to("cuda:0")
+ self.iu_vivit_proj = nn.Linear(768, 4096).to("cuda:0")
+
+ self.iu_vivit_norm_1 = bridge_norm_layer(4096).to("cuda:0")
+ self.iu_vivit_f1_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+ self.iu_vivit_f2_1 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias).to("cuda:0")
+ self.iu_vivit_f3_1 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+
+ self.iu_vivit_norm_2 = bridge_norm_layer(4096).to("cuda:0")
+ self.iu_vivit_f1_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+ self.iu_vivit_f2_2 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias).to("cuda:0")
+ self.iu_vivit_f3_2 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+
+ self.iu_vivit_norm_3 = bridge_norm_layer(4096).to("cuda:0")
+ self.iu_vivit_f1_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
+ self.iu_vivit_f2_3 = nn.Linear(4096 * self.feature_scaler, 4096, bias=bridge_bias).to("cuda:0")
+ self.iu_vivit_f3_3 = nn.Linear(4096, 4096 * self.feature_scaler, bias=bridge_bias).to("cuda:0")
  print(f'ViViT initialized...')

  # 4. llama
@@ -152,7 +152,7 @@ class M2UGen(nn.Module):

  if torch.cuda.is_available():
  torch.set_default_tensor_type(torch.cuda.HalfTensor)
- self.llama = Transformer(self.model_args)
+ self.llama = Transformer(self.model_args).to("cuda:1")
  torch.set_default_tensor_type(torch.FloatTensor)

  if load_llama:
@@ -207,7 +207,7 @@ class M2UGen(nn.Module):
  # 5. projector
  self.output_projector = ProjectionLayer(4096, self.model_args.output_dim_tokens,
  num_input_tokens=self.model_args.num_gen_audio_tokens,
- num_output_tokens=self.model_args.num_output_tokens)
+ num_output_tokens=self.model_args.num_output_tokens).to("cuda:1")

  # 6. Generator
  if self.args.music_decoder.lower() == "audioldm2":
@@ -217,7 +217,7 @@ class M2UGen(nn.Module):
  print(f'Initialize AudioLDM2...')
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
  self.generation_model = AudioLDM2Pipeline.from_pretrained(self.args.music_decoder_path, torch_dtype=dtype)
- self.generation_model.to("cuda")
+ self.generation_model.to("cuda:1")
  print(f'AudioLDM2 initialized...')
  else:
  # The model files for MusicGen can be downloaded here in case of network issues:
@@ -225,7 +225,7 @@ class M2UGen(nn.Module):
  # And set the music_decoder_path argument to directory with the model files
  print(f'Initialize MusicGen...')
  self.generation_processor = AutoProcessor.from_pretrained(self.args.music_decoder_path)
- self.generation_model = MusicgenForConditionalGeneration.from_pretrained(self.args.music_decoder_path)
+ self.generation_model = MusicgenForConditionalGeneration.from_pretrained(self.args.music_decoder_path).to("cuda:1")
  self.generation_model.eval()
  print(f'MusicGen initialized...')
  self.music_decoder = self.args.music_decoder.lower()
@@ -233,7 +233,7 @@ class M2UGen(nn.Module):
  # 4. prefix
  self.query_layer = 20
  self.query_len = 1
- self.prefix_query = nn.Embedding(self.query_layer * self.query_len, self.model_args.dim)
+ self.prefix_query = nn.Embedding(self.query_layer * self.query_len, self.model_args.dim).to("cuda:1")

  # 5. knn
  self.knn = knn
@@ -672,10 +672,10 @@ class M2UGen(nn.Module):

  total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

- tokens = torch.full((bsz, total_len), 0).cuda().long()
+ tokens = torch.full((bsz, total_len), 0).to("cuda:1").long()

  for k, t in enumerate(prompts):
- tokens[k, : len(t)] = torch.tensor(t).cuda().long()
+ tokens[k, : len(t)] = torch.tensor(t).to("cuda:1").long()
  input_text_mask = tokens != 0
  start_pos = min_prompt_size
  prev_pos = 0
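
The pattern applied throughout this commit is a manual two-GPU split of M2UGen: the MERT, ViT and ViViT encoders and their bridge layers are pinned to cuda:0, while the LLaMA transformer, output projector, prefix embedding, music decoder and the generation-time token tensors are pinned to cuda:1. Below is a minimal, self-contained sketch of that pattern, assuming a machine with at least two CUDA devices; the module names and sizes are illustrative and not taken from m2ugen.py.

import torch
import torch.nn as nn

class TwoGpuSketch(nn.Module):
    # Hypothetical model: an "encoder" pinned to cuda:0 and a "decoder" pinned to
    # cuda:1, mirroring the .to("cuda:0") / .to("cuda:1") calls added in this commit.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(1024, 4096), nn.GELU()).to("cuda:0")
        self.decoder = nn.Sequential(nn.Linear(4096, 4096), nn.GELU()).to("cuda:1")

    def forward(self, x):
        x = x.to("cuda:0")   # inputs start on the encoder's device
        h = self.encoder(x)
        h = h.to("cuda:1")   # activations must be moved explicitly between GPUs
        return self.decoder(h)

if __name__ == "__main__":
    assert torch.cuda.device_count() >= 2, "this sketch assumes two CUDA devices"
    model = TwoGpuSketch()
    out = model(torch.randn(2, 1024))
    print(out.shape, out.device)  # torch.Size([2, 4096]) cuda:1

    # Generation-time tensors follow the same rule, as in the diff's
    # tokens = torch.full(...).to("cuda:1"): create or move them onto the
    # device that holds the weights they will be combined with.
    tokens = torch.full((2, 16), 0, dtype=torch.long, device="cuda:1")

Without such cross-device moves in the forward path, an operation between a cuda:0 activation and a cuda:1 weight raises a device-mismatch RuntimeError, which is why the module placement and the tensor placement have to change together.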