cocktailpeanut commited on
Commit
b563c74
1 Parent(s): 4941fcb
Files changed (2) hide show
  1. app.py +9 -1
  2. requirements.txt +3 -3
app.py CHANGED
@@ -16,6 +16,11 @@ model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
16
  if torch.cuda.is_available():
17
  net.load_state_dict(torch.load(model_path))
18
  net=net.cuda()
 
 
 
 
 
19
  else:
20
  net.load_state_dict(torch.load(model_path,map_location="cpu"))
21
  net.eval()
@@ -39,6 +44,9 @@ def process(image):
39
  im_tensor = torch.unsqueeze(im_tensor,0)
40
  im_tensor = torch.divide(im_tensor,255.0)
41
  im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
 
 
 
42
  if torch.cuda.is_available():
43
  im_tensor=im_tensor.cuda()
44
 
@@ -103,4 +111,4 @@ examples = [['./input.jpg'],]
103
  demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
104
 
105
  if __name__ == "__main__":
106
- demo.launch(share=False)
 
16
  if torch.cuda.is_available():
17
  net.load_state_dict(torch.load(model_path))
18
  net=net.cuda()
19
+ device = "cuda"
20
+ elif torch.backends.mps.is_available():
21
+ net.load_state_dict(torch.load(model_path))
22
+ net=net.to("mps")
23
+ device = "mps"
24
  else:
25
  net.load_state_dict(torch.load(model_path,map_location="cpu"))
26
  net.eval()
 
44
  im_tensor = torch.unsqueeze(im_tensor,0)
45
  im_tensor = torch.divide(im_tensor,255.0)
46
  im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
47
+ if device == "cuda":
48
+ elif device == "mps":
49
+ im_tensor=im_tensor.to("mps")
50
  if torch.cuda.is_available():
51
  im_tensor=im_tensor.cuda()
52
 
 
111
  demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
112
 
113
  if __name__ == "__main__":
114
+ demo.launch(share=False)
requirements.txt CHANGED
@@ -1,9 +1,9 @@
1
  gradio
2
  gradio_imageslider
3
- torch
4
- torchvision
5
  pillow
6
  numpy
7
  typing
8
  gitpython
9
- huggingface_hub
 
1
  gradio
2
  gradio_imageslider
3
+ #torch
4
+ #torchvision
5
  pillow
6
  numpy
7
  typing
8
  gitpython
9
+ huggingface_hub