Spaces:
Running
Running
Update model/CLAPSep.py
Browse files- model/CLAPSep.py +2 -6
model/CLAPSep.py
CHANGED
@@ -69,17 +69,13 @@ class CLAPSep(nn.Module):
|
|
69 |
pred = self.istft(mag_y * cos_y, mag_y * sin_y, length=length)
|
70 |
return pred
|
71 |
|
72 |
-
def inference_from_data(self, mixed,
|
73 |
self.eval()
|
74 |
real, imag = self.stft(mixed)
|
75 |
mag, cos, sin = magphase(real, imag)
|
76 |
self.features.append(mag)
|
77 |
with torch.no_grad():
|
78 |
-
|
79 |
-
use_tensor=True), dim=0, chunks=2)
|
80 |
-
embed_pos = torch.zeros_like(embed_pos) if pos_prompt == '' else embed_pos
|
81 |
-
embed_neg = torch.zeros_like(embed_neg) if neg_prompt == '' else embed_neg
|
82 |
-
embed = torch.concat([embed_pos, embed_neg], dim=-1)
|
83 |
self.audio_branch({"waveform": self.resampler(mixed)})
|
84 |
mask = self.decoder_model(hidden_state=self.features[-1], skip_features=self.features[:-1], embed=embed)
|
85 |
pred = self.wav_reconstruct(mask, mag, cos, sin, length=mixed.size(-1))
|
|
|
69 |
pred = self.istft(mag_y * cos_y, mag_y * sin_y, length=length)
|
70 |
return pred
|
71 |
|
72 |
+
def inference_from_data(self, mixed, embed_pos, embed_neg):
|
73 |
self.eval()
|
74 |
real, imag = self.stft(mixed)
|
75 |
mag, cos, sin = magphase(real, imag)
|
76 |
self.features.append(mag)
|
77 |
with torch.no_grad():
|
78 |
+
embed = torch.nn.functional.normalize(torch.concat([embed_pos, embed_neg], dim=-1), dim=-1)
|
|
|
|
|
|
|
|
|
79 |
self.audio_branch({"waveform": self.resampler(mixed)})
|
80 |
mask = self.decoder_model(hidden_state=self.features[-1], skip_features=self.features[:-1], embed=embed)
|
81 |
pred = self.wav_reconstruct(mask, mag, cos, sin, length=mixed.size(-1))
|