AisakaMikoto commited on
Commit
5961cfc
·
verified ·
1 Parent(s): 33171a6

Update model/CLAPSep.py

Browse files
Files changed (1) hide show
  1. model/CLAPSep.py +2 -6
model/CLAPSep.py CHANGED
@@ -69,17 +69,13 @@ class CLAPSep(nn.Module):
69
  pred = self.istft(mag_y * cos_y, mag_y * sin_y, length=length)
70
  return pred
71
 
72
- def inference_from_data(self, mixed, pos_prompt, neg_prompt):
73
  self.eval()
74
  real, imag = self.stft(mixed)
75
  mag, cos, sin = magphase(real, imag)
76
  self.features.append(mag)
77
  with torch.no_grad():
78
- embed_pos, embed_neg = torch.chunk(self.clap_model.get_text_embedding(pos_prompt + neg_prompt,
79
- use_tensor=True), dim=0, chunks=2)
80
- embed_pos = torch.zeros_like(embed_pos) if pos_prompt == '' else embed_pos
81
- embed_neg = torch.zeros_like(embed_neg) if neg_prompt == '' else embed_neg
82
- embed = torch.concat([embed_pos, embed_neg], dim=-1)
83
  self.audio_branch({"waveform": self.resampler(mixed)})
84
  mask = self.decoder_model(hidden_state=self.features[-1], skip_features=self.features[:-1], embed=embed)
85
  pred = self.wav_reconstruct(mask, mag, cos, sin, length=mixed.size(-1))
 
69
  pred = self.istft(mag_y * cos_y, mag_y * sin_y, length=length)
70
  return pred
71
 
72
+ def inference_from_data(self, mixed, embed_pos, embed_neg):
73
  self.eval()
74
  real, imag = self.stft(mixed)
75
  mag, cos, sin = magphase(real, imag)
76
  self.features.append(mag)
77
  with torch.no_grad():
78
+ embed = torch.nn.functional.normalize(torch.concat([embed_pos, embed_neg], dim=-1), dim=-1)
 
 
 
 
79
  self.audio_branch({"waveform": self.resampler(mixed)})
80
  mask = self.decoder_model(hidden_state=self.features[-1], skip_features=self.features[:-1], embed=embed)
81
  pred = self.wav_reconstruct(mask, mag, cos, sin, length=mixed.size(-1))