Pinwheel commited on
Commit
364c262
β€’
1 Parent(s): 13a42ba

Add CPU support

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -51,8 +51,11 @@ def save_merged_boxes(predictions, image_np):
51
  return roi
52
 
53
  # Load the EfficientNet model
54
- def load_model(model_path):
55
- model = torch.load(model_path)
 
 
 
56
  model = model.to(device)
57
  model.eval() # Set the model to evaluation mode
58
  return model
 
51
  return roi
52
 
53
  # Load the EfficientNet model
54
+ def load_model(model_path, map_location=None):
55
+ if map_location:
56
+ model = torch.load(model_path, map_location=map_location)
57
+ else:
58
+ model = torch.load(model_path)
59
  model = model.to(device)
60
  model.eval() # Set the model to evaluation mode
61
  return model