haritsahm commited on
Commit
ce29101
·
1 Parent(s): 4152b54

Add torch device selection

Browse files
Files changed (1) hide show
  1. main.py +6 -1
main.py CHANGED
@@ -10,13 +10,17 @@ from PIL import Image
10
  from models import phc_models
11
  from utils import utils
12
 
 
 
 
 
13
  BILATERIAL_WEIGHT = 'weights/phresnet18_cbis2views.pt'
14
  BILATERAL_MODEL = phc_models.PHCResNet18(
15
  channels=2, n=2, num_classes=1, visualize=True)
16
  BILATERAL_MODEL.add_top_blocks(num_classes=1)
17
  BILATERAL_MODEL.load_state_dict(torch.load(
18
  BILATERIAL_WEIGHT, map_location='cpu'))
19
- BILATERAL_MODEL = BILATERAL_MODEL.to('cpu')
20
  BILATERAL_MODEL.eval()
21
  INPUT_HEIGHT, INPUT_WIDTH = 600, 500
22
 
@@ -143,6 +147,7 @@ def predict_bilateral(cc_file, mlo_file):
143
 
144
  images_t = torch.from_numpy(images)
145
  images_t = images_t.unsqueeze(0) # Add batch dimension
 
146
 
147
  out, _, out_refiner = BILATERAL_MODEL(images_t)
148
 
 
10
  from models import phc_models
11
  from utils import utils
12
 
13
+ device = torch.device('cpu')
14
+ if torch.cuda.is_available():
15
+ device = torch.device('cuda:0')
16
+
17
  BILATERIAL_WEIGHT = 'weights/phresnet18_cbis2views.pt'
18
  BILATERAL_MODEL = phc_models.PHCResNet18(
19
  channels=2, n=2, num_classes=1, visualize=True)
20
  BILATERAL_MODEL.add_top_blocks(num_classes=1)
21
  BILATERAL_MODEL.load_state_dict(torch.load(
22
  BILATERIAL_WEIGHT, map_location='cpu'))
23
+ BILATERAL_MODEL = BILATERAL_MODEL.to(device)
24
  BILATERAL_MODEL.eval()
25
  INPUT_HEIGHT, INPUT_WIDTH = 600, 500
26
 
 
147
 
148
  images_t = torch.from_numpy(images)
149
  images_t = images_t.unsqueeze(0) # Add batch dimension
150
+ images_t = images_t.to(device)
151
 
152
  out, _, out_refiner = BILATERAL_MODEL(images_t)
153