AlexN commited on
Commit
a19b5e0
1 Parent(s): af94e9a

Update TractionModel.py

Browse files
Files changed (1) hide show
  1. app/TractionModel.py +2 -2
app/TractionModel.py CHANGED
@@ -53,7 +53,7 @@ def create_model():
53
  return model
54
 
55
 
56
- def load_weights(model, path='model.pt'):
57
- checkpoint = torch.load(path, map_location=torch.device('cpu'))
58
  model.load_state_dict(checkpoint)
59
  return model
 
53
  return model
54
 
55
 
56
+ def load_weights(model, path='model.pt', device_='cpu'):
57
+ checkpoint = torch.load(path, map_location=torch.device(device_))
58
  model.load_state_dict(checkpoint)
59
  return model