leonelhs commited on
Commit
699e9ad
·
1 Parent(s): 03bfe96

add net choices

Browse files
Files changed (1) hide show
  1. app.py +29 -19
app.py CHANGED
@@ -1,16 +1,23 @@
1
  import gradio as gr
2
  import torch
3
  from carvekit.api.interface import Interface
 
 
4
  from carvekit.ml.wrap.fba_matting import FBAMatting
5
  from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
 
6
  from carvekit.pipelines.postprocessing import MattingMethod
7
  from carvekit.pipelines.preprocessing import PreprocessingStub
8
  from carvekit.trimap.generator import TrimapGenerator
9
 
10
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
 
12
- # Check doc strings for more information
13
- seg_net = TracerUniversalB7(device=device, batch_size=1)
 
 
 
 
14
 
15
  fba = FBAMatting(device=device,
16
  input_tensor_size=2048,
@@ -24,18 +31,19 @@ postprocessing = MattingMethod(matting_module=fba,
24
  trimap_generator=trimap,
25
  device=device)
26
 
27
- interface = Interface(pre_pipe=preprocessing,
28
- post_pipe=postprocessing,
29
- seg_pipe=seg_net)
30
 
31
 
32
- def generate_trimap(original):
33
- mask = seg_net([original])
34
  return trimap(original_image=original, mask=mask[0])
35
 
36
 
37
- def predict(image):
38
- return interface([image])[0]
 
 
 
39
 
40
 
41
  footer = r"""
@@ -49,35 +57,37 @@ Demo based on <a href='https://github.com/OPHoperHPO/image-background-remove-too
49
  """
50
 
51
  with gr.Blocks(title="CarveKit") as app:
52
-
53
  gr.Markdown("<center><h1><b>CarveKit</b></h1></center>")
54
  gr.HTML("<center><h3>High-quality image background removal</h3></center>")
55
 
56
  with gr.Tabs() as tabs:
57
  with gr.TabItem("Remove background", id=0):
58
- with gr.Row().style(equal_height=False):
59
  with gr.Column():
60
  input_img = gr.Image(type="pil", label="Input image")
 
 
 
 
61
  run_btn = gr.Button(variant="primary")
62
  with gr.Column():
63
  output_img = gr.Image(type="pil", label="result")
64
 
65
- run_btn.click(predict, [input_img], [output_img])
66
 
67
  with gr.TabItem("Trimap generator", id=1):
68
- with gr.Row().style(equal_height=False):
69
  with gr.Column():
70
  trimap_input = gr.Image(type="pil", label="Input image")
 
 
 
 
71
  trimap_btn = gr.Button(variant="primary")
72
  with gr.Column():
73
  trimap_output = gr.Image(type="pil", label="result")
74
 
75
- trimap_btn.click(generate_trimap, [trimap_input], [trimap_output])
76
-
77
- # with gr.Row():
78
- # examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
79
- # examples = gr.Dataset(components=[input_img], samples=examples_data)
80
- # examples.click(lambda x: x[0], [examples], [input_img])
81
 
82
  with gr.Row():
83
  gr.HTML(footer)
 
1
  import gradio as gr
2
  import torch
3
  from carvekit.api.interface import Interface
4
+ from carvekit.ml.wrap.basnet import BASNET
5
+ from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
6
  from carvekit.ml.wrap.fba_matting import FBAMatting
7
  from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
8
+ from carvekit.ml.wrap.u2net import U2NET
9
  from carvekit.pipelines.postprocessing import MattingMethod
10
  from carvekit.pipelines.preprocessing import PreprocessingStub
11
  from carvekit.trimap.generator import TrimapGenerator
12
 
13
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
 
15
+ segment_net = {
16
+ "U2NET": U2NET(device=device, batch_size=1),
17
+ "BASNET": BASNET(device=device, batch_size=1),
18
+ "DeepLabV3": DeepLabV3(device=device, batch_size=1),
19
+ "TracerUniversalB7": TracerUniversalB7(device=device, batch_size=1)
20
+ }
21
 
22
  fba = FBAMatting(device=device,
23
  input_tensor_size=2048,
 
31
  trimap_generator=trimap,
32
  device=device)
33
 
34
+ method_choices = [k for k, v in segment_net.items()]
 
 
35
 
36
 
37
+ def generate_trimap(method, original):
38
+ mask = segment_net[method]([original])
39
  return trimap(original_image=original, mask=mask[0])
40
 
41
 
42
+ def predict(method, image):
43
+ method = segment_net[method]
44
+ return Interface(pre_pipe=preprocessing,
45
+ post_pipe=postprocessing,
46
+ seg_pipe=method)([image])[0]
47
 
48
 
49
  footer = r"""
 
57
  """
58
 
59
  with gr.Blocks(title="CarveKit") as app:
 
60
  gr.Markdown("<center><h1><b>CarveKit</b></h1></center>")
61
  gr.HTML("<center><h3>High-quality image background removal</h3></center>")
62
 
63
  with gr.Tabs() as tabs:
64
  with gr.TabItem("Remove background", id=0):
65
+ with gr.Row(equal_height=False):
66
  with gr.Column():
67
  input_img = gr.Image(type="pil", label="Input image")
68
+ drp_itf = gr.Dropdown(
69
+ value="TracerUniversalB7",
70
+ label="Segmentor model",
71
+ choices=method_choices)
72
  run_btn = gr.Button(variant="primary")
73
  with gr.Column():
74
  output_img = gr.Image(type="pil", label="result")
75
 
76
+ run_btn.click(predict, [drp_itf, input_img], [output_img])
77
 
78
  with gr.TabItem("Trimap generator", id=1):
79
+ with gr.Row(equal_height=False):
80
  with gr.Column():
81
  trimap_input = gr.Image(type="pil", label="Input image")
82
+ drp_itf = gr.Dropdown(
83
+ value="TracerUniversalB7",
84
+ label="Segmentor model",
85
+ choices=method_choices)
86
  trimap_btn = gr.Button(variant="primary")
87
  with gr.Column():
88
  trimap_output = gr.Image(type="pil", label="result")
89
 
90
+ trimap_btn.click(generate_trimap, [drp_itf, trimap_input], [trimap_output])
 
 
 
 
 
91
 
92
  with gr.Row():
93
  gr.HTML(footer)