import gradio as gr
import torch
from carvekit.api.interface import Interface
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.trimap.generator import TrimapGenerator

# Use the GPU when available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Check doc strings for more information.
# TRACER-B7 produces the initial segmentation mask.
seg_net = TracerUniversalB7(device=device, batch_size=1)

# FBA Matting refines object edges using a trimap built from the mask.
fba = FBAMatting(device=device, input_tensor_size=2048, batch_size=1)
trimap = TrimapGenerator()

preprocessing = PreprocessingStub()
postprocessing = MattingMethod(matting_module=fba,
                               trimap_generator=trimap,
                               device=device)

interface = Interface(pre_pipe=preprocessing,
                      post_pipe=postprocessing,
                      seg_pipe=seg_net)


def predict(image):
    # Interface takes a list of images and returns a list of cutouts.
    return interface([image])[0]


footer = r"""
<center>Demo based on CarveKit</center>
""" with gr.Blocks(title="CarveKit") as app: gr.HTML("

CarveKit

") gr.HTML("

Automated high-quality background removal framework for an image using neural networks.

") with gr.Row(): with gr.Column(): input_img = gr.Image(type="pil", label="Input image") run_btn = gr.Button(variant="primary") with gr.Column(): output_img = gr.Image(type="pil", label="result") run_btn.click(predict, [input_img], [output_img]) with gr.Row(): examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)] examples = gr.Dataset(components=[input_img], samples=examples_data) examples.click(lambda x: x[0], [examples], [input_img]) with gr.Row(): gr.HTML(footer) app.launch(share=False, debug=True, enable_queue=True, show_error=True)