haritsahm
commited on
Commit
·
ce29101
1
Parent(s):
4152b54
Add torch device selection
Browse files
main.py
CHANGED
@@ -10,13 +10,17 @@ from PIL import Image
|
|
10 |
from models import phc_models
|
11 |
from utils import utils
|
12 |
|
|
|
|
|
|
|
|
|
13 |
BILATERIAL_WEIGHT = 'weights/phresnet18_cbis2views.pt'
|
14 |
BILATERAL_MODEL = phc_models.PHCResNet18(
|
15 |
channels=2, n=2, num_classes=1, visualize=True)
|
16 |
BILATERAL_MODEL.add_top_blocks(num_classes=1)
|
17 |
BILATERAL_MODEL.load_state_dict(torch.load(
|
18 |
BILATERIAL_WEIGHT, map_location='cpu'))
|
19 |
-
BILATERAL_MODEL = BILATERAL_MODEL.to(
|
20 |
BILATERAL_MODEL.eval()
|
21 |
INPUT_HEIGHT, INPUT_WIDTH = 600, 500
|
22 |
|
@@ -143,6 +147,7 @@ def predict_bilateral(cc_file, mlo_file):
|
|
143 |
|
144 |
images_t = torch.from_numpy(images)
|
145 |
images_t = images_t.unsqueeze(0) # Add batch dimension
|
|
|
146 |
|
147 |
out, _, out_refiner = BILATERAL_MODEL(images_t)
|
148 |
|
|
|
10 |
from models import phc_models
|
11 |
from utils import utils
|
12 |
|
13 |
+
device = torch.device('cpu')
|
14 |
+
if torch.cuda.is_available():
|
15 |
+
device = torch.device('cuda:0')
|
16 |
+
|
17 |
BILATERIAL_WEIGHT = 'weights/phresnet18_cbis2views.pt'
|
18 |
BILATERAL_MODEL = phc_models.PHCResNet18(
|
19 |
channels=2, n=2, num_classes=1, visualize=True)
|
20 |
BILATERAL_MODEL.add_top_blocks(num_classes=1)
|
21 |
BILATERAL_MODEL.load_state_dict(torch.load(
|
22 |
BILATERIAL_WEIGHT, map_location='cpu'))
|
23 |
+
BILATERAL_MODEL = BILATERAL_MODEL.to(device)
|
24 |
BILATERAL_MODEL.eval()
|
25 |
INPUT_HEIGHT, INPUT_WIDTH = 600, 500
|
26 |
|
|
|
147 |
|
148 |
images_t = torch.from_numpy(images)
|
149 |
images_t = images_t.unsqueeze(0) # Add batch dimension
|
150 |
+
images_t = images_t.to(device)
|
151 |
|
152 |
out, _, out_refiner = BILATERAL_MODEL(images_t)
|
153 |
|