cocktailpeanut commited on
Commit
4f4309e
1 Parent(s): b563c74
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -18,7 +18,7 @@ if torch.cuda.is_available():
18
  net=net.cuda()
19
  device = "cuda"
20
  elif torch.backends.mps.is_available():
21
- net.load_state_dict(torch.load(model_path))
22
  net=net.to("mps")
23
  device = "mps"
24
  else:
@@ -45,10 +45,9 @@ def process(image):
45
  im_tensor = torch.divide(im_tensor,255.0)
46
  im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
47
  if device == "cuda":
 
48
  elif device == "mps":
49
  im_tensor=im_tensor.to("mps")
50
- if torch.cuda.is_available():
51
- im_tensor=im_tensor.cuda()
52
 
53
  #inference
54
  result=net(im_tensor)
 
18
  net=net.cuda()
19
  device = "cuda"
20
  elif torch.backends.mps.is_available():
21
+ net.load_state_dict(torch.load(model_path,map_location="mps"))
22
  net=net.to("mps")
23
  device = "mps"
24
  else:
 
45
  im_tensor = torch.divide(im_tensor,255.0)
46
  im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
47
  if device == "cuda":
48
+ im_tensor=im_tensor.cuda()
49
  elif device == "mps":
50
  im_tensor=im_tensor.to("mps")
 
 
51
 
52
  #inference
53
  result=net(im_tensor)