CSH-1220 committed on
Commit 117486a · 1 Parent(s): 55f08a9

File updates regarding memory-saving

APadapter/ap_adapter/attention_processor.py CHANGED
@@ -309,7 +309,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
             the weight scale of image prompt.
     """
 
-    def __init__(self, hidden_size, name, cross_attention_dim=None, num_tokens=4, scale=1.0, do_copy = False):
+    def __init__(self, hidden_size, name, flag = 'normal', cross_attention_dim=None, num_tokens=4, text_scale = 1.0 , scale=1.0, do_copy = False):
         super().__init__()
 
         if not hasattr(F, "scaled_dot_product_attention"):
@@ -320,10 +320,12 @@ class IPAttnProcessor2_0(torch.nn.Module):
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.num_tokens = num_tokens
+        self.text_scale = text_scale
         self.scale = scale
         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
         self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
         self.name = name
+        self.flag = flag
         # Below is for copying the weight of the original weight to the \
         if do_copy:
             print("do copy")
@@ -451,7 +453,8 @@ class IPAttnProcessor2_0(torch.nn.Module):
         ip_hidden_states = ip_hidden_states.to(query.dtype)
         # print("hidden_states",hidden_states)
         # print("ip_hidden_states",ip_hidden_states)
-        hidden_states = hidden_states + self.scale * ip_hidden_states
+        # print(f'{self.flag} Hello, I pass here!')
+        hidden_states = self.text_scale * hidden_states + self.scale * ip_hidden_states
         # print("ip_hidden_states",ip_hidden_states.shape)
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
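Note: the processor's output mix changes from `hidden_states + self.scale * ip_hidden_states` to `self.text_scale * hidden_states + self.scale * ip_hidden_states`, and `__init__` gains `text_scale` and `flag`. A minimal sketch of how the updated processor might be constructed (the hidden size, layer name, and values below are illustrative placeholders, not taken from this commit):

    # Hypothetical instantiation of the updated processor (all values are placeholders).
    proc = IPAttnProcessor2_0(
        hidden_size=1024,
        name="up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor",
        flag="normal",       # new argument; stored on the processor (only referenced in a commented-out debug print)
        cross_attention_dim=768,
        num_tokens=8,
        text_scale=1.0,      # new argument: weight on the text-attention branch
        scale=1.0,           # weight on the audio-prompt (AP adapter) branch
        do_copy=False,
    )
    # Inside __call__ the two branches are now combined as:
    #   hidden_states = self.text_scale * hidden_states + self.scale * ip_hidden_states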
app.py CHANGED
@@ -1,26 +1,39 @@
 import os
+import gc
 import torch
+import shutil
+import atexit
 import torchaudio
 import numpy as np
 import gradio as gr
 from pipeline.morph_pipeline_successed_ver1 import AudioLDM2MorphPipeline
+os.environ["CUDA_VISIBLE_DEVICES"] = "6"
 # Initialize AudioLDM2 Pipeline
-pipeline = AudioLDM2MorphPipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=torch.float32)
+torch.cuda.set_device(0)
+dtype = torch.float32
+pipeline = AudioLDM2MorphPipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=dtype)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 pipeline.to(device)
 
-# Audio morphing function
-def morph_audio(audio_file1, audio_file2, prompt1, prompt2, negative_prompt1="Low quality", negative_prompt2="Low quality"):
+
+def morph_audio(audio_file1, audio_file2, num_inference_steps, prompt1='', prompt2='', negative_prompt1="Low quality", negative_prompt2="Low quality"):
     save_lora_dir = "output"
+    if os.path.exists(save_lora_dir):
+        shutil.rmtree(save_lora_dir)
     os.makedirs(save_lora_dir, exist_ok=True)
 
     # Load audio and compute duration
-    waveform, sample_rate = torchaudio.load(audio_file1)
-    duration = waveform.shape[1] / sample_rate
-    duration = int(duration)
+    waveform1, sample_rate1 = torchaudio.load(audio_file1)
+    duration1 = waveform1.shape[1] / sample_rate1
+    waveform2, sample_rate2 = torchaudio.load(audio_file2)
+    duration2 = waveform2.shape[1] / sample_rate2
+
+    # Compare durations and take the shorter one
+    duration = int(min(duration1, duration2))
 
     # Perform morphing using the pipeline
     _ = pipeline(
+        dtype = dtype,
         audio_file=audio_file1,
         audio_file2=audio_file2,
         audio_length_in_s=duration,
@@ -33,13 +46,13 @@ def morph_audio(audio_file1, audio_file2, prompt1, prompt2, negative_prompt1="Lo
         save_lora_dir=save_lora_dir,
         use_adain=True,
         use_reschedule=False,
-        num_inference_steps=50,
+        num_inference_steps=num_inference_steps,
         lamd=0.6,
         output_path=save_lora_dir,
         num_frames=5,
         fix_lora=None,
         use_lora=True,
-        lora_steps=50,
+        lora_steps=2,
         noisy_latent_with_lora=True,
         morphing_with_lora=True,
         use_morph_prompt=True,
@@ -51,32 +64,125 @@ def morph_audio(audio_file1, audio_file2, prompt1, prompt2, negative_prompt1="Lo
         [os.path.join(save_lora_dir, file) for file in os.listdir(save_lora_dir) if file.endswith(".wav")],
         key=lambda x: int(os.path.splitext(os.path.basename(x))[0])
     )
+    del waveform1, waveform2, _
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
     return output_paths
 
+def morph_audio_with_morphing_factor(audio_file1, audio_file2, alpha, num_inference_steps, prompt1='', prompt2='', negative_prompt1="Low quality", negative_prompt2="Low quality"):
+    save_lora_dir = "output"
+    if os.path.exists(save_lora_dir):
+        shutil.rmtree(save_lora_dir)
+    os.makedirs(save_lora_dir, exist_ok=True)
+
+    # Load audio and compute duration
+    waveform1, sample_rate1 = torchaudio.load(audio_file1)
+    duration1 = waveform1.shape[1] / sample_rate1
+    waveform2, sample_rate2 = torchaudio.load(audio_file2)
+    duration2 = waveform2.shape[1] / sample_rate2
+
+    # Compare durations and take the shorter one
+    duration = int(min(duration1, duration2))
+    try:
+        # Perform morphing using the pipeline
+        _ = pipeline(
+            dtype = dtype,
+            morphing_factor = alpha,
+            audio_file=audio_file1,
+            audio_file2=audio_file2,
+            audio_length_in_s=duration,
+            time_pooling=2,
+            freq_pooling=2,
+            prompt_1=prompt1,
+            prompt_2=prompt2,
+            negative_prompt_1=negative_prompt1,
+            negative_prompt_2=negative_prompt2,
+            save_lora_dir=save_lora_dir,
+            use_adain=True,
+            use_reschedule=False,
+            num_inference_steps=num_inference_steps,
+            lamd=0.6,
+            output_path=save_lora_dir,
+            num_frames=5,
+            fix_lora=None,
+            use_lora=True,
+            lora_steps=2,
+            noisy_latent_with_lora=True,
+            morphing_with_lora=True,
+            use_morph_prompt=True,
+            guidance_scale=7.5,
+        )
+        output_paths = os.path.join(save_lora_dir, 'interpolated.wav')
+
+    except RuntimeError as e:
+        if "CUDA out of memory" in str(e):
+            print("CUDA out of memory. Releasing unused memory...")
+            torch.cuda.empty_cache()
+            gc.collect()
+        raise e
+    # # Collect the output file paths
+    # del waveform1, waveform2, _
+    # torch.cuda.empty_cache()
+    # gc.collect()
+
+    return output_paths
+
+def cleanup_output_dir():
+    save_lora_dir = "output"
+    if os.path.exists(save_lora_dir):
+        shutil.rmtree(save_lora_dir)
+        print(f"Cleaned up directory: {save_lora_dir}")
+atexit.register(cleanup_output_dir)
 
 # Gradio interface function
-def interface(audio1, audio2, prompt1, prompt2):
-    output_paths = morph_audio(audio1, audio2, prompt1, prompt2)
+def interface(audio1, audio2, alpha, num_inference_steps):
+    output_paths = morph_audio_with_morphing_factor(audio1, audio2, alpha, num_inference_steps)
     return output_paths
 
 # Gradio Interface
-demo = gr.Interface(
-    fn=interface,
-    inputs=[
-        gr.Audio(label="Upload Audio File 1", type="filepath"),
-        gr.Audio(label="Upload Audio File 2", type="filepath"),
-        # gr.Slider(4, 6, step=1, label="Octave 1"),
-        gr.Textbox(label="Prompt for Audio File 1"),
-        gr.Textbox(label="Prompt for Audio File 2")
-    ],
-    outputs=[
-        gr.Audio(label="Morphing audio 1"),
-        gr.Audio(label="Morphing audio 2"),
-        gr.Audio(label="Morphing audio 3"),
-        gr.Audio(label="Morphing audio 4"),
-        gr.Audio(label="Morphing audio 5"),
-    ],
-)
+# demo = gr.Interface(
+#     fn=interface,
+#     inputs=[
+#         gr.Audio(label="Upload Audio File 1", type="filepath"),
+#         gr.Audio(label="Upload Audio File 2", type="filepath"),
+#         gr.Slider(0, 1, step=0.01, label="Interpolation Alpha"),
+#         gr.Slider(10, 50, step=1, label="Inference Steps"),
+#         # gr.Textbox(label="Prompt for Audio File 1"),
+#         # gr.Textbox(label="Prompt for Audio File 2"),
+#     ],
+#     outputs=gr.Audio(label="Interpolated Audio")
+# )
+
+
+with gr.Blocks() as demo:
+    with gr.Tab("Sound Morphing with fixed frames."):
+        gr.Markdown("### Upload two audio files for morphing")
+        with gr.Row():
+            audio1 = gr.Audio(label="Upload Audio File 1", type="filepath")
+            audio2 = gr.Audio(label="Upload Audio File 2", type="filepath")
+        num_inference_steps = gr.Slider(10, 50, step=1, label="Inference Steps", value=50)
+        outputs = [
+            gr.Audio(label="Morphing audio 1"),
+            gr.Audio(label="Morphing audio 2"),
+            gr.Audio(label="Morphing audio 3"),
+            gr.Audio(label="Morphing audio 4"),
+            gr.Audio(label="Morphing audio 5"),
+        ]
+        submit_btn1 = gr.Button("Submit")
+        submit_btn1.click(morph_audio, inputs=[audio1, audio2, num_inference_steps], outputs=outputs)
+
+    with gr.Tab("Sound Morphing with specified morphing factor."):
+        gr.Markdown("### Upload two audio files for morphing")
+        with gr.Row():
+            audio1 = gr.Audio(label="Upload Audio File 1", type="filepath")
+            audio2 = gr.Audio(label="Upload Audio File 2", type="filepath")
+        alpha = gr.Slider(0, 1, step=0.01, label="Interpolation Alpha")
+        num_inference_steps = gr.Slider(10, 50, step=1, label="Inference Steps", value=50)
+        outputs = gr.Audio(label="Interpolated Audio")
+        submit_btn2 = gr.Button("Submit")
+        submit_btn2.click(morph_audio_with_morphing_factor, inputs=[audio1, audio2, alpha, num_inference_steps], outputs=outputs)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=True)
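The memory-saving changes in app.py follow one pattern: drop large references once a request finishes, then run garbage collection and release PyTorch's cached CUDA blocks, and do the same when a CUDA out-of-memory error is caught. A condensed, self-contained sketch of that pattern (the tensor and `run_pipeline` below are illustrative placeholders, not part of this commit):

    import gc
    import torch

    def run_pipeline(x):
        # Placeholder standing in for the AudioLDM2MorphPipeline call.
        return x.sum()

    waveform = torch.randn(1, 16000 * 10)  # stand-in for a loaded waveform
    try:
        result = run_pipeline(waveform)
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            # Free what we can before re-raising, mirroring morph_audio_with_morphing_factor.
            torch.cuda.empty_cache()
            gc.collect()
        raise
    finally:
        # Mirror the per-request cleanup in morph_audio: drop references,
        # collect, and return cached GPU memory to the allocator.
        del waveform
        gc.collect()
        torch.cuda.empty_cache()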
pipeline/morph_pipeline_successed_ver1.py CHANGED
@@ -227,6 +227,10 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.aud1_dict = dict()
         self.aud2_dict = dict()
+        ap_adapter_path = 'pytorch_model.bin'
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        dtype = next(self.vae.parameters()).dtype
+        self.pipeline_trained = self.init_trained_pipeline(ap_adapter_path, device, dtype, ap_scale=1.0, text_ap_scale=1.0)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -928,8 +932,11 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
         DEVICE = torch.device(
             "cuda") if torch.cuda.is_available() else torch.device("cpu")
         mel_spect_tensor = wav_to_mel(audio_path, duration=audio_length_in_s).unsqueeze(0)
-        output_path = audio_path.replace('.wav', '_fbank.png')
-        visualize_mel_spectrogram(mel_spect_tensor, output_path)
+        # if audio_path.endswith('.wav'):
+        #     output_path = audio_path.replace('.wav', '_fbank.png')
+        # elif audio_path.endswith('.mp3'):
+        #     output_path = audio_path.replace('.mp3', '_fbank.png')
+        # visualize_mel_spectrogram(mel_spect_tensor, output_path)
         mel_spect_tensor = mel_spect_tensor.to(next(self.vae.parameters()).dtype)
         # print(f'mel_spect_tensor dtype: {mel_spect_tensor.dtype}')
         # print(f'self.vae dtype: {next(self.vae.parameters()).dtype}')
@@ -1062,6 +1069,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
     def __call__(
         self,
        dtype,
+        morphing_factor = None,
        audio_file = None,
        audio_file2 = None,
        ap_scale = 1.0,
@@ -1118,11 +1126,11 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                 cross_attention_dim = cross[layer_num % 8]
                 layer_num += 1
                 if cross_attention_dim == 768:
-                    attn_procs[name].scale = IPAttnProcessor2_0(
+                    attn_procs[name] = IPAttnProcessor2_0(
                         hidden_size=hidden_size,
                         name=name,
                         cross_attention_dim=cross_attention_dim,
-                        text_scale=100,
+                        text_scale=text_ap_scale,
                         scale=ap_scale,
                         num_tokens=8,
                         do_copy=False
@@ -1141,7 +1149,6 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
             processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
             processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
         self.unet.set_attn_processor(attn_procs)
-        self.pipeline_trained = self.init_trained_pipeline(ap_adapter_path, device, dtype, ap_scale, text_ap_scale)
 
         # 1. Pre-check
         height, original_waveform_length = self.pre_check(audio_length_in_s, prompt_1, callback_steps, negative_prompt_1)
@@ -1200,7 +1207,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
         # ------- For the first audio file -------
         original_processor = list(self.unet.attn_processors.values())[0]
         if noisy_latent_with_lora:
-            self.unet = load_lora(self.unet, lora_1, lora_2, 0)
+            self.unet = load_lora(self.unet, lora_1, lora_2, 0, dtype=dtype)
         # We directly use the latent representation of the audio file for VAE's decoder as the 1st ground truth
         audio_latent = self.aud2latent(audio_file, audio_length_in_s).to(device)
         # aud_noise_1 is the noisy latent representation of the audio file 1
@@ -1211,7 +1218,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
 
         # ------- For the second audio file -------
         if noisy_latent_with_lora:
-            self.unet = load_lora(self.unet, lora_1, lora_2, 1)
+            self.unet = load_lora(self.unet, lora_1, lora_2, 1, dtype=dtype)
         # We directly use the latent representation of the audio file for VAE's decoder as the 1st ground truth
         audio_latent = self.aud2latent(audio_file2, audio_length_in_s)
         # aud_noise_2 is the noisy latent representation of the audio file 2
@@ -1220,12 +1227,13 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
         self.unet.set_attn_processor(original_processor)
         # After reconstructed the audio file 1, we set the original processor back
         original_processor = list(self.unet.attn_processors.values())[0]
+
         def morph(alpha_list, desc):
             audios = []
             # if attn_beta is not None:
             if self.use_lora:
                 self.unet = load_lora(
-                    self.unet, lora_1, lora_2, 0 if fix_lora is None else fix_lora)
+                    self.unet, lora_1, lora_2, 0 if fix_lora is None else fix_lora, dtype=dtype)
             attn_processor_dict = {}
             for k in self.unet.attn_processors.keys():
                 # print(k)
@@ -1266,7 +1274,7 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
             scipy.io.wavfile.write(file_path, rate=16000, data=first_audio)
             if self.use_lora:
                 self.unet = load_lora(
-                    self.unet, lora_1, lora_2, 1 if fix_lora is None else fix_lora)
+                    self.unet, lora_1, lora_2, 1 if fix_lora is None else fix_lora, dtype=dtype)
             attn_processor_dict = {}
             for k in self.unet.attn_processors.keys():
                 if do_replace_attn(k):
@@ -1304,12 +1312,24 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
             scipy.io.wavfile.write(file_path, rate=16000, data=last_audio)
 
             self.unet.set_attn_processor(original_processor)
-
-            for i in tqdm(range(1, num_frames - 1), desc=desc):
-                alpha = alpha_list[i]
+
+            if morphing_factor is not None:
+                alpha = morphing_factor
+                if alpha==0:
+                    file_path = os.path.join(self.output_path, f"interpolated.wav")
+                    scipy.io.wavfile.write(file_path, rate=16000, data=first_audio)
+                    self.unet.set_attn_processor(original_processor)
+                    audios.append(audio)
+                    return audios
+                elif alpha==1:
+                    file_path = os.path.join(self.output_path, f"interpolated.wav")
+                    scipy.io.wavfile.write(file_path, rate=16000, data=last_audio)
+                    self.unet.set_attn_processor(original_processor)
+                    audios.append(audio)
+                    return audios
                 if self.use_lora:
                     self.unet = load_lora(
-                        self.unet, lora_1, lora_2, alpha if fix_lora is None else fix_lora)
+                        self.unet, lora_1, lora_2, alpha if fix_lora is None else fix_lora, dtype=dtype)
 
                 attn_processor_dict = {}
                 for k in self.unet.attn_processors.keys():
@@ -1338,25 +1358,70 @@ class AudioLDM2MorphPipeline(DiffusionPipeline,TextualInversionLoaderMixin):
                     prompt_embeds_2,
                     attention_mask_2,
                     generated_prompt_embeds_2,
-                    alpha_list[i],
+                    alpha,
                     original_processor,
                     attn_processor_dict,
                     use_morph_prompt,
                     morphing_with_lora
                 )
-                file_path = os.path.join(self.output_path, f"{i:02d}.wav")
+                file_path = os.path.join(self.output_path, f"interpolated.wav")
                 scipy.io.wavfile.write(file_path, rate=16000, data=audio)
                 self.unet.set_attn_processor(original_processor)
                 audios.append(audio)
-            audios = [first_audio] + audios + [last_audio]
+            else:
+                for i in tqdm(range(1, num_frames - 1), desc=desc):
+                    alpha = alpha_list[i]
+                    if self.use_lora:
+                        self.unet = load_lora(
+                            self.unet, lora_1, lora_2, alpha if fix_lora is None else fix_lora, dtype=dtype)
+
+                    attn_processor_dict = {}
+                    for k in self.unet.attn_processors.keys():
+                        if do_replace_attn(k):
+                            if self.use_lora:
+                                attn_processor_dict[k] = LoadProcessor(
+                                    self.unet.attn_processors[k], k, self.aud1_dict, self.aud2_dict, alpha, attn_beta, lamd)
+                            else:
+                                attn_processor_dict[k] = LoadProcessor(
+                                    original_processor, k, self.aud1_dict, self.aud2_dict, alpha, attn_beta, lamd)
+                        else:
+                            attn_processor_dict[k] = self.unet.attn_processors[k]
+                    audio, latents = self.cal_latent(
+                        audio_length_in_s,
+                        time_pooling,
+                        freq_pooling,
+                        num_inference_steps,
+                        guidance_scale,
+                        aud_noise_1,
+                        aud_noise_2,
+                        prompt_1,
+                        prompt_2,
+                        prompt_embeds_1,
+                        attention_mask_1,
+                        generated_prompt_embeds_1,
+                        prompt_embeds_2,
+                        attention_mask_2,
+                        generated_prompt_embeds_2,
+                        alpha_list[i],
+                        original_processor,
+                        attn_processor_dict,
+                        use_morph_prompt,
+                        morphing_with_lora
+                    )
+                    file_path = os.path.join(self.output_path, f"{i:02d}.wav")
+                    scipy.io.wavfile.write(file_path, rate=16000, data=audio)
+                    self.unet.set_attn_processor(original_processor)
+                    audios.append(audio)
+                audios = [first_audio] + audios + [last_audio]
             return audios
+
         with torch.no_grad():
             if self.use_reschedule:
                 alpha_scheduler = AlphaScheduler()
                 alpha_list = list(torch.linspace(0, 1, num_frames))
                 audios_pt = morph(alpha_list, "Sampling...")
                 audios_pt = [torch.tensor(aud).unsqueeze(0)
-                             for aud in audios_pt]
+                    for aud in audios_pt]
                 alpha_scheduler.from_imgs(audios_pt)
                 alpha_list = alpha_scheduler.get_list()
                 audios = morph(alpha_list, "Reschedule...")
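The new `morphing_factor` argument changes what `morph()` produces: when it is set, a single `interpolated.wav` is written at that alpha (with alpha 0 and 1 short-circuiting to the two reconstructions); otherwise the original sweep over the interior frames still runs. A simplified sketch of that dispatch (illustrative only; the real method also reloads LoRA weights and swaps attention processors for each alpha):

    # Toy model of the branch added to morph(); the file names match the diff,
    # everything else is schematic.
    def plan_outputs(morphing_factor, alpha_list, num_frames):
        if morphing_factor is not None:
            # One interpolation point; alpha 0 / 1 reuse the endpoint reconstructions.
            return [("interpolated.wav", morphing_factor)]
        # Original behaviour: interior frames of the morphing trajectory.
        return [(f"{i:02d}.wav", alpha_list[i]) for i in range(1, num_frames - 1)]

    print(plan_outputs(0.5, [0.0, 0.25, 0.5, 0.75, 1.0], 5))   # [('interpolated.wav', 0.5)]
    print(plan_outputs(None, [0.0, 0.25, 0.5, 0.75, 1.0], 5))  # [('01.wav', 0.25), ('02.wav', 0.5), ('03.wav', 0.75)]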
utils/lora_utils_successed_ver1.py CHANGED
@@ -664,6 +664,8 @@ def train_lora(audio_path ,dtype ,time_pooling ,freq_pooling ,prompt, negative_p
         weight_name=weight_name,
         safe_serialization=safe_serialization
     )
+
+    del loss_history, unet_lora_layers, unet, vae, text_encoder, text_encoder_2, GPT2, projection_model, vocoder, noise_scheduler, optimizer, lr_scheduler, model
 
 def load_lora(unet, lora_0, lora_1, alpha, dtype):
     attn_procs = unet.attn_processors
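The `del` added at the end of `train_lora` drops the function's references to the models, optimizer, and scheduler so that, together with the `gc.collect()` / `torch.cuda.empty_cache()` calls added in app.py, their GPU memory can actually be reclaimed between requests. A toy illustration of the idea (the small model below is a placeholder for the UNet, VAE, text encoders, and so on):

    import gc
    import torch

    model = torch.nn.Linear(4096, 4096)            # stand-in for the LoRA-trained UNet
    optimizer = torch.optim.AdamW(model.parameters())

    # ... training loop would run here ...

    # Mirrors the new `del loss_history, unet_lora_layers, unet, vae, ...` line:
    # once the last references are gone, a collection plus empty_cache() lets
    # PyTorch return the cached blocks to the GPU allocator.
    del model, optimizer
    gc.collect()
    torch.cuda.empty_cache()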