cocktailpeanut commited on
Commit
bc3d254
·
1 Parent(s): 22b8c91
Files changed (2) hide show
  1. app.py +6 -5
  2. requirements.txt +2 -2
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from loadimg import load_img
3
- import spaces
4
  from transformers import AutoModelForImageSegmentation
5
  import torch
6
  from torchvision import transforms
@@ -11,10 +11,11 @@ import numpy as np
11
  import os
12
  import tempfile
13
  import uuid
 
14
 
15
  torch.set_float32_matmul_precision("medium")
16
 
17
- device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
  birefnet = AutoModelForImageSegmentation.from_pretrained(
20
  "ZhengPeng7/BiRefNet", trust_remote_code=True
@@ -29,7 +30,7 @@ transform_image = transforms.Compose(
29
  )
30
 
31
 
32
- @spaces.GPU
33
  def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down"):
34
  try:
35
  # Load the video using moviepy
@@ -111,7 +112,7 @@ def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=
111
 
112
  def process(image, bg):
113
  image_size = image.size
114
- input_images = transform_image(image).unsqueeze(0).to("cuda")
115
  # Prediction
116
  with torch.no_grad():
117
  preds = birefnet(input_images)[-1].sigmoid().cpu()
@@ -191,4 +192,4 @@ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
191
  )
192
 
193
  if __name__ == "__main__":
194
- demo.launch(show_error=True)
 
1
  import gradio as gr
2
  from loadimg import load_img
3
+ #import spaces
4
  from transformers import AutoModelForImageSegmentation
5
  import torch
6
  from torchvision import transforms
 
11
  import os
12
  import tempfile
13
  import uuid
14
+ import devicetorch
15
 
16
  torch.set_float32_matmul_precision("medium")
17
 
18
+ device = devicetorch.get(torch)
19
 
20
  birefnet = AutoModelForImageSegmentation.from_pretrained(
21
  "ZhengPeng7/BiRefNet", trust_remote_code=True
 
30
  )
31
 
32
 
33
+ #@spaces.GPU
34
  def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down"):
35
  try:
36
  # Load the video using moviepy
 
112
 
113
  def process(image, bg):
114
  image_size = image.size
115
+ input_images = transform_image(image).unsqueeze(0).to(device)
116
  # Prediction
117
  with torch.no_grad():
118
  preds = birefnet(input_images)[-1].sigmoid().cpu()
 
192
  )
193
 
194
  if __name__ == "__main__":
195
+ demo.launch(show_error=True)
requirements.txt CHANGED
@@ -1,7 +1,6 @@
1
  torch
2
  accelerate
3
  opencv-python
4
- spaces
5
  pillow
6
  numpy
7
  timm
@@ -15,4 +14,5 @@ gradio
15
  gradio_imageslider
16
  loadimg>=0.1.1
17
  moviepy
18
- pydub
 
 
1
  torch
2
  accelerate
3
  opencv-python
 
4
  pillow
5
  numpy
6
  timm
 
14
  gradio_imageslider
15
  loadimg>=0.1.1
16
  moviepy
17
+ pydub
18
+ devicetorch