mimbres commited on
Commit
e60064d
·
verified ·
1 Parent(s): 6f613c4

Update model_helper.py

Browse files
Files changed (1) hide show
  1. model_helper.py +1 -1
model_helper.py CHANGED
@@ -116,7 +116,7 @@ def load_model_checkpoint(args=None, device='cpu'):
116
  eval_subtask_key=args.eval_subtask_key,
117
  write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
118
  ).to(device)
119
- checkpoint = torch.load(dir_info["last_ckpt_path"], map_location=device)
120
  state_dict = checkpoint['state_dict']
121
  new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
122
  model.load_state_dict(new_state_dict, strict=False)
 
116
  eval_subtask_key=args.eval_subtask_key,
117
  write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
118
  ).to(device)
119
+ checkpoint = torch.load(dir_info["last_ckpt_path"], map_location=device, weights_only=False)
120
  state_dict = checkpoint['state_dict']
121
  new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
122
  model.load_state_dict(new_state_dict, strict=False)