Update model_helper.py
Browse files- model_helper.py +1 -1
model_helper.py
CHANGED
@@ -116,7 +116,7 @@ def load_model_checkpoint(args=None, device='cpu'):
|
|
116 |
eval_subtask_key=args.eval_subtask_key,
|
117 |
write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
|
118 |
).to(device)
|
119 |
-
checkpoint = torch.load(dir_info["last_ckpt_path"], map_location=device)
|
120 |
state_dict = checkpoint['state_dict']
|
121 |
new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
|
122 |
model.load_state_dict(new_state_dict, strict=False)
|
|
|
116 |
eval_subtask_key=args.eval_subtask_key,
|
117 |
write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
|
118 |
).to(device)
|
119 |
+
checkpoint = torch.load(dir_info["last_ckpt_path"], map_location=device, weights_only=False)
|
120 |
state_dict = checkpoint['state_dict']
|
121 |
new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
|
122 |
model.load_state_dict(new_state_dict, strict=False)
|