Spaces:
Runtime error
Runtime error
cocktailpeanut
commited on
Commit
•
b563c74
1
Parent(s):
4941fcb
update
Browse files- app.py +9 -1
- requirements.txt +3 -3
app.py
CHANGED
@@ -16,6 +16,11 @@ model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
|
|
16 |
if torch.cuda.is_available():
|
17 |
net.load_state_dict(torch.load(model_path))
|
18 |
net=net.cuda()
|
|
|
|
|
|
|
|
|
|
|
19 |
else:
|
20 |
net.load_state_dict(torch.load(model_path,map_location="cpu"))
|
21 |
net.eval()
|
@@ -39,6 +44,9 @@ def process(image):
|
|
39 |
im_tensor = torch.unsqueeze(im_tensor,0)
|
40 |
im_tensor = torch.divide(im_tensor,255.0)
|
41 |
im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
|
|
|
|
|
|
|
42 |
if torch.cuda.is_available():
|
43 |
im_tensor=im_tensor.cuda()
|
44 |
|
@@ -103,4 +111,4 @@ examples = [['./input.jpg'],]
|
|
103 |
demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
|
104 |
|
105 |
if __name__ == "__main__":
|
106 |
-
demo.launch(share=False)
|
|
|
16 |
if torch.cuda.is_available():
|
17 |
net.load_state_dict(torch.load(model_path))
|
18 |
net=net.cuda()
|
19 |
+
device = "cuda"
|
20 |
+
elif torch.backends.mps.is_available():
|
21 |
+
net.load_state_dict(torch.load(model_path))
|
22 |
+
net=net.to("mps")
|
23 |
+
device = "mps"
|
24 |
else:
|
25 |
net.load_state_dict(torch.load(model_path,map_location="cpu"))
|
26 |
net.eval()
|
|
|
44 |
im_tensor = torch.unsqueeze(im_tensor,0)
|
45 |
im_tensor = torch.divide(im_tensor,255.0)
|
46 |
im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
|
47 |
+
if device == "cuda":
|
48 |
+
elif device == "mps":
|
49 |
+
im_tensor=im_tensor.to("mps")
|
50 |
if torch.cuda.is_available():
|
51 |
im_tensor=im_tensor.cuda()
|
52 |
|
|
|
111 |
demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples, title=title, description=description)
|
112 |
|
113 |
if __name__ == "__main__":
|
114 |
+
demo.launch(share=False)
|
requirements.txt
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
gradio
|
2 |
gradio_imageslider
|
3 |
-
torch
|
4 |
-
torchvision
|
5 |
pillow
|
6 |
numpy
|
7 |
typing
|
8 |
gitpython
|
9 |
-
huggingface_hub
|
|
|
1 |
gradio
|
2 |
gradio_imageslider
|
3 |
+
#torch
|
4 |
+
#torchvision
|
5 |
pillow
|
6 |
numpy
|
7 |
typing
|
8 |
gitpython
|
9 |
+
huggingface_hub
|