zhzluke96 commited on
Commit
db7cbc6
1 Parent(s): ebc4336
modules/Enhancer/ResembleEnhance.py CHANGED
@@ -37,7 +37,7 @@ class ResembleEnhance:
37
  enhancer = Enhancer(hparams)
38
  state_dict = torch.load(
39
  Path(MODELS_DIR) / "resemble-enhance" / "mp_rank_00_model_states.pt",
40
- map_location="cpu",
41
  )["module"]
42
  enhancer.load_state_dict(state_dict)
43
  enhancer.to(self.device).eval()
 
37
  enhancer = Enhancer(hparams)
38
  state_dict = torch.load(
39
  Path(MODELS_DIR) / "resemble-enhance" / "mp_rank_00_model_states.pt",
40
+ map_location=self.device,
41
  )["module"]
42
  enhancer.load_state_dict(state_dict)
43
  enhancer.to(self.device).eval()