saefro991 commited on
Commit
2d8f24b
1 Parent(s): 3d7f208

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -91,17 +91,18 @@ def transfer(audio):
91
  checkpoint_path=ckpt_path,
92
  config=config,
93
  strict=False
94
- )
95
 
96
  encoder_src = src_model.encoder.to(device)
97
  channelfeats_src = src_model.channelfeats.to(device)
98
  channel_src = src_model.channel.to(device)
99
-
100
- _, enc_hidden_src = encoder_src(
101
- melspec_src.unsqueeze(0).unsqueeze(1).transpose(2, 3).to(device)
102
- )
103
- chfeats_src = channelfeats_src(enc_hidden_src)
104
- wav_transfer = channel_src(wav_tar.unsqueeze(0), chfeats_src)
 
105
  wav_transfer = wav_transfer.cpu().detach().numpy()[0, :]
106
  return sr, wav_transfer
107
 
 
91
  checkpoint_path=ckpt_path,
92
  config=config,
93
  strict=False
94
+ ).eval()
95
 
96
  encoder_src = src_model.encoder.to(device)
97
  channelfeats_src = src_model.channelfeats.to(device)
98
  channel_src = src_model.channel.to(device)
99
+
100
+ with torch.no_grad():
101
+ _, enc_hidden_src = encoder_src(
102
+ melspec_src.unsqueeze(0).unsqueeze(1).transpose(2, 3).to(device)
103
+ )
104
+ chfeats_src = channelfeats_src(enc_hidden_src)
105
+ wav_transfer = channel_src(wav_tar.unsqueeze(0), chfeats_src)
106
  wav_transfer = wav_transfer.cpu().detach().numpy()[0, :]
107
  return sr, wav_transfer
108