Spaces:
Runtime error
Runtime error
cocktailpeanut
commited on
Commit
•
4f4309e
1
Parent(s):
b563c74
update
Browse files
app.py
CHANGED
@@ -18,7 +18,7 @@ if torch.cuda.is_available():
|
|
18 |
net=net.cuda()
|
19 |
device = "cuda"
|
20 |
elif torch.backends.mps.is_available():
|
21 |
-
net.load_state_dict(torch.load(model_path))
|
22 |
net=net.to("mps")
|
23 |
device = "mps"
|
24 |
else:
|
@@ -45,10 +45,9 @@ def process(image):
|
|
45 |
im_tensor = torch.divide(im_tensor,255.0)
|
46 |
im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
|
47 |
if device == "cuda":
|
|
|
48 |
elif device == "mps":
|
49 |
im_tensor=im_tensor.to("mps")
|
50 |
-
if torch.cuda.is_available():
|
51 |
-
im_tensor=im_tensor.cuda()
|
52 |
|
53 |
#inference
|
54 |
result=net(im_tensor)
|
|
|
18 |
net=net.cuda()
|
19 |
device = "cuda"
|
20 |
elif torch.backends.mps.is_available():
|
21 |
+
net.load_state_dict(torch.load(model_path,map_location="mps"))
|
22 |
net=net.to("mps")
|
23 |
device = "mps"
|
24 |
else:
|
|
|
45 |
im_tensor = torch.divide(im_tensor,255.0)
|
46 |
im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
|
47 |
if device == "cuda":
|
48 |
+
im_tensor=im_tensor.cuda()
|
49 |
elif device == "mps":
|
50 |
im_tensor=im_tensor.to("mps")
|
|
|
|
|
51 |
|
52 |
#inference
|
53 |
result=net(im_tensor)
|