Spaces:
Sleeping
Sleeping
update loss
Browse files- inference.py +1 -1
- modules/loss.py +91 -24
inference.py
CHANGED
@@ -99,6 +99,7 @@ class MasteringStyleTransfer:
|
|
99 |
target = ito_config['clap_text_prompt']
|
100 |
print(f'ito_config clap_distance_fn: {ito_config["clap_distance_fn"]}')
|
101 |
total_loss = self.clap_loss(output_audio, target, self.args.sample_rate, distance_fn=ito_config['clap_distance_fn'])
|
|
|
102 |
|
103 |
if total_loss < min_loss:
|
104 |
min_loss = total_loss.item()
|
@@ -243,7 +244,6 @@ class MasteringStyleTransfer:
|
|
243 |
if isinstance(param_value, torch.Tensor):
|
244 |
param_value = param_value.item()
|
245 |
|
246 |
-
print(f"fx name: {fx_name} param_name: {param_name}")
|
247 |
if fx_name in param_mapper and param_name in param_mapper[fx_name]:
|
248 |
friendly_name, unit, min_val, max_val = param_mapper[fx_name][param_name]
|
249 |
if unit=='%':
|
|
|
99 |
target = ito_config['clap_text_prompt']
|
100 |
print(f'ito_config clap_distance_fn: {ito_config["clap_distance_fn"]}')
|
101 |
total_loss = self.clap_loss(output_audio, target, self.args.sample_rate, distance_fn=ito_config['clap_distance_fn'])
|
102 |
+
print(f'total_loss: {total_loss}')
|
103 |
|
104 |
if total_loss < min_loss:
|
105 |
min_loss = total_loss.item()
|
|
|
244 |
if isinstance(param_value, torch.Tensor):
|
245 |
param_value = param_value.item()
|
246 |
|
|
|
247 |
if fx_name in param_mapper and param_name in param_mapper[fx_name]:
|
248 |
friendly_name, unit, min_val, max_val = param_mapper[fx_name][param_name]
|
249 |
if unit=='%':
|
modules/loss.py
CHANGED
@@ -185,25 +185,35 @@ class CLAPFeatureLoss(nn.Module):
|
|
185 |
self.target_sample_rate = 48000 # CLAP expects 48kHz audio
|
186 |
self.model = laion_clap.CLAP_Module(enable_fusion=False)
|
187 |
self.model.load_ckpt() # download the default pretrained checkpoint
|
|
|
|
|
|
|
|
|
188 |
|
189 |
-
def forward(self, input_audio, target, sample_rate, distance_fn='
|
190 |
# Process input audio
|
191 |
-
|
|
|
|
|
|
|
|
|
192 |
|
193 |
# Process target (audio or text)
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
|
|
|
|
200 |
|
201 |
# Compute loss using the specified distance function
|
202 |
loss = self.compute_distance(input_embed, target_embed, distance_fn)
|
203 |
|
204 |
return loss
|
205 |
|
206 |
-
def
|
207 |
# Ensure input is in the correct shape (N, C, T)
|
208 |
if audio.dim() == 2:
|
209 |
audio = audio.unsqueeze(1)
|
@@ -219,19 +229,7 @@ class CLAPFeatureLoss(nn.Module):
|
|
219 |
# Quantize audio data
|
220 |
audio = self.quantize(audio)
|
221 |
|
222 |
-
|
223 |
-
with torch.no_grad():
|
224 |
-
embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
|
225 |
-
return embed
|
226 |
-
|
227 |
-
def process_text(self, text):
|
228 |
-
# Get CLAP embeddings for text
|
229 |
-
# ensure input is a list of strings
|
230 |
-
if not isinstance(text, list):
|
231 |
-
text = [text]
|
232 |
-
with torch.no_grad():
|
233 |
-
embed = self.model.get_text_embedding(text, use_tensor=True)
|
234 |
-
return embed
|
235 |
|
236 |
def compute_distance(self, x, y, distance_fn):
|
237 |
if distance_fn == 'mse':
|
@@ -249,11 +247,80 @@ class CLAPFeatureLoss(nn.Module):
|
|
249 |
audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
|
250 |
return audio
|
251 |
|
252 |
-
def resample(self, audio,
|
253 |
resampler = torchaudio.transforms.Resample(
|
254 |
-
orig_freq=
|
255 |
).to(audio.device)
|
256 |
return resampler(audio)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
|
258 |
|
259 |
"""
|
|
|
185 |
self.target_sample_rate = 48000 # CLAP expects 48kHz audio
|
186 |
self.model = laion_clap.CLAP_Module(enable_fusion=False)
|
187 |
self.model.load_ckpt() # download the default pretrained checkpoint
|
188 |
+
|
189 |
+
# Freeze the CLAP model parameters
|
190 |
+
for param in self.model.parameters():
|
191 |
+
param.requires_grad = False
|
192 |
|
193 |
+
def forward(self, input_audio, target, sample_rate, distance_fn='mse'):
|
194 |
# Process input audio
|
195 |
+
with torch.no_grad():
|
196 |
+
input_audio = self.preprocess_audio(input_audio, sample_rate)
|
197 |
+
|
198 |
+
with torch.enable_grad():
|
199 |
+
input_embed = self.model.get_audio_embedding_from_data(x=input_audio, use_tensor=True)
|
200 |
|
201 |
# Process target (audio or text)
|
202 |
+
with torch.no_grad():
|
203 |
+
if isinstance(target, torch.Tensor):
|
204 |
+
target_audio = self.preprocess_audio(target, sample_rate)
|
205 |
+
target_embed = self.model.get_audio_embedding_from_data(x=target_audio, use_tensor=True)
|
206 |
+
elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
|
207 |
+
target_embed = self.model.get_text_embedding(target, use_tensor=True)
|
208 |
+
else:
|
209 |
+
raise ValueError("Target must be either audio tensor or text (string or list of strings)")
|
210 |
|
211 |
# Compute loss using the specified distance function
|
212 |
loss = self.compute_distance(input_embed, target_embed, distance_fn)
|
213 |
|
214 |
return loss
|
215 |
|
216 |
+
def preprocess_audio(self, audio, sample_rate):
|
217 |
# Ensure input is in the correct shape (N, C, T)
|
218 |
if audio.dim() == 2:
|
219 |
audio = audio.unsqueeze(1)
|
|
|
229 |
# Quantize audio data
|
230 |
audio = self.quantize(audio)
|
231 |
|
232 |
+
return audio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
def compute_distance(self, x, y, distance_fn):
|
235 |
if distance_fn == 'mse':
|
|
|
247 |
audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
|
248 |
return audio
|
249 |
|
250 |
+
def resample(self, audio, orig_sample_rate):
|
251 |
resampler = torchaudio.transforms.Resample(
|
252 |
+
orig_freq=orig_sample_rate, new_freq=self.target_sample_rate
|
253 |
).to(audio.device)
|
254 |
return resampler(audio)
|
255 |
+
|
256 |
+
# def forward(self, input_audio, target, sample_rate, distance_fn='cosine'):
|
257 |
+
# # Process input audio
|
258 |
+
# input_embed = self.process_audio(input_audio, sample_rate)
|
259 |
+
|
260 |
+
# # Process target (audio or text)
|
261 |
+
# if isinstance(target, torch.Tensor):
|
262 |
+
# target_embed = self.process_audio(target, sample_rate)
|
263 |
+
# elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
|
264 |
+
# target_embed = self.process_text(target)
|
265 |
+
# else:
|
266 |
+
# raise ValueError("Target must be either audio tensor or text (string or list of strings)")
|
267 |
+
|
268 |
+
# # Compute loss using the specified distance function
|
269 |
+
# loss = self.compute_distance(input_embed, target_embed, distance_fn)
|
270 |
+
|
271 |
+
# return loss
|
272 |
+
|
273 |
+
# def process_audio(self, audio, sample_rate):
|
274 |
+
# # Ensure input is in the correct shape (N, C, T)
|
275 |
+
# if audio.dim() == 2:
|
276 |
+
# audio = audio.unsqueeze(1)
|
277 |
+
|
278 |
+
# # Convert to mono if stereo
|
279 |
+
# if audio.shape[1] > 1:
|
280 |
+
# audio = audio.mean(dim=1, keepdim=True)
|
281 |
+
|
282 |
+
# # Resample if necessary
|
283 |
+
# if sample_rate != self.target_sample_rate:
|
284 |
+
# audio = self.resample(audio, sample_rate)
|
285 |
+
|
286 |
+
# # Quantize audio data
|
287 |
+
# audio = self.quantize(audio)
|
288 |
+
|
289 |
+
# # Get CLAP embeddings
|
290 |
+
# with torch.no_grad():
|
291 |
+
# embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
|
292 |
+
# return embed
|
293 |
+
|
294 |
+
# def process_text(self, text):
|
295 |
+
# # Get CLAP embeddings for text
|
296 |
+
# # ensure input is a list of strings
|
297 |
+
# if not isinstance(text, list):
|
298 |
+
# text = [text]
|
299 |
+
# with torch.no_grad():
|
300 |
+
# embed = self.model.get_text_embedding(text, use_tensor=True)
|
301 |
+
# return embed
|
302 |
+
|
303 |
+
# def compute_distance(self, x, y, distance_fn):
|
304 |
+
# if distance_fn == 'mse':
|
305 |
+
# return F.mse_loss(x, y)
|
306 |
+
# elif distance_fn == 'l1':
|
307 |
+
# return F.l1_loss(x, y)
|
308 |
+
# elif distance_fn == 'cosine':
|
309 |
+
# return 1 - F.cosine_similarity(x, y).mean()
|
310 |
+
# else:
|
311 |
+
# raise ValueError(f"Unsupported distance function: {distance_fn}")
|
312 |
+
|
313 |
+
# def quantize(self, audio):
|
314 |
+
# audio = audio.squeeze(1) # Remove channel dimension
|
315 |
+
# audio = torch.clamp(audio, -1.0, 1.0)
|
316 |
+
# audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
|
317 |
+
# return audio
|
318 |
+
|
319 |
+
# def resample(self, audio, input_sample_rate):
|
320 |
+
# resampler = torchaudio.transforms.Resample(
|
321 |
+
# orig_freq=input_sample_rate, new_freq=self.target_sample_rate
|
322 |
+
# ).to(audio.device)
|
323 |
+
# return resampler(audio)
|
324 |
|
325 |
|
326 |
"""
|