fffiloni commited on
Commit
855986f
1 Parent(s): 2252f3d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import tempfile
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from rembg import remove
7
+ import subprocess
8
+ from glob import glob
9
+
10
+ def remove_background(input_url):
11
+ # Create a temporary folder for downloaded and processed images
12
+ temp_dir = tempfile.mkdtemp()
13
+
14
+ # Download the image from the URL
15
+ image_path = os.path.join(temp_dir, 'input_image.png')
16
+ try:
17
+ image = Image.open(requests.get(input_url, stream=True).raw)
18
+ image.save(image_path)
19
+ except Exception as e:
20
+ shutil.rmtree(temp_dir)
21
+ return f"Error downloading or saving the image: {str(e)}"
22
+
23
+ # Run background removal
24
+ try:
25
+ removed_bg_path = os.path.join(temp_dir, 'output_image_rmbg.png')
26
+ img = Image.open(image_path)
27
+ result = remove(img)
28
+ result.save(removed_bg_path)
29
+ except Exception as e:
30
+ shutil.rmtree(temp_dir)
31
+ return f"Error removing background: {str(e)}"
32
+
33
+ return removed_bg_path, temp_dir
34
+
35
+ def run_inference(temp_dir):
36
+ # Define the inference configuration
37
+ inference_config = "configs/inference-768-6view.yaml"
38
+ pretrained_model = "pengHTYX/PSHuman_Unclip_768_6views"
39
+ crop_size = 740
40
+ seed = 600
41
+ num_views = 7
42
+ save_mode = "rgb"
43
+
44
+ try:
45
+ # Run the inference command
46
+ subprocess.run(
47
+ [
48
+ "python", "inference.py",
49
+ "--config", inference_config,
50
+ f"pretrained_model_name_or_path={pretrained_model}",
51
+ f"validation_dataset.crop_size={crop_size}",
52
+ f"with_smpl=false",
53
+ f"validation_dataset.root_dir={temp_dir}",
54
+ f"seed={seed}",
55
+ f"num_views={num_views}",
56
+ f"save_mode={save_mode}"
57
+ ],
58
+ check=True
59
+ )
60
+
61
+ # Collect the output images
62
+ output_images = glob(os.path.join(temp_dir, "*.png"))
63
+ return output_images
64
+ except subprocess.CalledProcessError as e:
65
+ return f"Error during inference: {str(e)}"
66
+
67
+ def process_image(input_url):
68
+ # Remove background
69
+ removed_bg_path, temp_dir = remove_background(input_url)
70
+
71
+ if isinstance(removed_bg_path, str) and removed_bg_path.startswith("Error"):
72
+ return removed_bg_path
73
+
74
+ # Run inference
75
+ output_images = run_inference(temp_dir)
76
+
77
+ if isinstance(output_images, str) and output_images.startswith("Error"):
78
+ shutil.rmtree(temp_dir)
79
+ return output_images
80
+
81
+ # Prepare outputs for display
82
+ results = []
83
+ for img_path in output_images:
84
+ results.append((img_path, img_path))
85
+
86
+ shutil.rmtree(temp_dir) # Cleanup temporary folder
87
+ return results
88
+
89
+ def gradio_interface():
90
+ with gr.Blocks() as app:
91
+ gr.Markdown("# Background Removal and Inference Pipeline")
92
+
93
+ with gr.Row():
94
+ input_url = gr.Textbox(label="Image URL", placeholder="Enter the URL of the image")
95
+ submit_button = gr.Button("Process")
96
+
97
+ output_gallery = gr.Gallery(label="Output Images").style(grid=[2], height="300px")
98
+
99
+ submit_button.click(process_image, inputs=[input_url], outputs=[output_gallery])
100
+
101
+ return app
102
+
103
+ # Launch the Gradio app
104
+ app = gradio_interface()
105
+ app.launch()