AlienKevin commited on
Commit
e9e9b11
1 Parent(s): 758aaaa

Map location to device when loading model

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -12,7 +12,7 @@ device = torch.device("cpu")
12
  model = WhisperAudioClassifier().to(device)
13
 
14
  # Load the state dict
15
- state_dict = torch.load(f"whisper-small-encoder-bisyllabic-jyutping/checkpoints/model_epoch_1_step_1800.pth")
16
 
17
  # Load the state dict into the model
18
  model.load_state_dict(state_dict)
 
12
  model = WhisperAudioClassifier().to(device)
13
 
14
  # Load the state dict
15
+ state_dict = torch.load(f"whisper-small-encoder-bisyllabic-jyutping/checkpoints/model_epoch_1_step_1800.pth", map_location=device)
16
 
17
  # Load the state dict into the model
18
  model.load_state_dict(state_dict)