File size: 3,796 Bytes
855986f 3c57e49 855986f 7bb79b2 855986f 3c57e49 855986f e7e8daa 855986f 59299a4 855986f 59299a4 855986f 59299a4 855986f 0a3422c 855986f cf177c1 855986f cd9f57a 855986f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import os
import shutil
import tempfile
import gradio as gr
from PIL import Image
from rembg import remove
import sys
import subprocess
from glob import glob
import requests
# Ensure the required package is installed
def install_dependencies():
    """Install the pinned pytorch3d commit with pip, or terminate the process.

    pip is invoked as ``<sys.executable> -m pip`` so the package is installed
    into the environment of the interpreter that is running this file. If pip
    exits with a non-zero status, the error is printed and ``sys.exit(1)`` is
    called.
    """
    pip_command = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/facebookresearch/pytorch3d.git@75ebeeaea0908c5527e7b1e305fbc7681382db47",
    ]
    try:
        subprocess.check_call(pip_command)
    except subprocess.CalledProcessError as err:
        print(f"Error installing dependencies: {err}")
        sys.exit(1)  # Exit the script if installation fails


# Install dependencies at the start. This runs as a module-level statement,
# before the Gradio UI further down the file is built.
install_dependencies()
def remove_background(input_url):
    """Strip the background from an uploaded image.

    Despite its name, ``input_url`` is handed directly to ``PIL.Image.open``,
    so it must be a local file path or an open file object. No network
    download happens here. The Gradio input in this file uses
    ``type="filepath"``.

    A fresh temporary directory is created. The input is re-saved there as
    ``input_image.png``, and the background-removed result is written next to
    it as ``output_image_rmbg.png``.

    Returns:
        ``(removed_bg_path, temp_dir)`` on success. The caller then owns
        ``temp_dir`` and is responsible for deleting it.
        On failure, an error-message string starting with ``"Error"`` is
        returned instead. In that case the temporary directory has already
        been removed.
    """
    temp_dir = tempfile.mkdtemp()
    image_path = os.path.join(temp_dir, 'input_image.png')
    removed_bg_path = os.path.join(temp_dir, 'output_image_rmbg.png')

    # Copy the uploaded file into the working directory. The context manager
    # closes the source file handle instead of leaking it until GC.
    try:
        with Image.open(input_url) as image:
            image.save(image_path)
    except Exception as e:
        # ignore_errors: a cleanup failure must not mask the real error.
        shutil.rmtree(temp_dir, ignore_errors=True)
        return f"Error downloading or saving the image: {str(e)}"

    # Run background removal on the saved copy and write the result.
    try:
        with Image.open(image_path) as img:
            result = remove(img)
            result.save(removed_bg_path)
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        return f"Error removing background: {str(e)}"

    return removed_bg_path, temp_dir
def run_inference(temp_dir):
    """Run the ``inference.py`` script on the images in ``temp_dir``.

    The script runs as a subprocess using the current interpreter
    (``sys.executable``). This is the same interpreter that
    ``install_dependencies`` installs pytorch3d into; a bare ``"python"``
    could resolve to a different environment on PATH. ``inference.py`` is
    located relative to the current working directory. The config path is
    also relative, so presumably it is resolved against the cwd by the script
    (TODO confirm).

    Args:
        temp_dir: directory passed to the script as
            ``validation_dataset.root_dir``.

    Returns:
        On success, a list of paths to every ``*.png`` directly inside
        ``temp_dir`` after the script exits, sorted so the gallery order is
        deterministic. Otherwise, an error-message string starting with
        ``"Error"``: when the script exits with a non-zero status, or the
        process cannot be launched at all.
    """
    # Define the inference configuration
    inference_config = "configs/inference-768-6view.yaml"
    pretrained_model = "pengHTYX/PSHuman_Unclip_768_6views"
    crop_size = 740
    seed = 600
    num_views = 7
    save_mode = "rgb"

    command = [
        sys.executable, "inference.py",
        "--config", inference_config,
        f"pretrained_model_name_or_path={pretrained_model}",
        f"validation_dataset.crop_size={crop_size}",
        "with_smpl=false",
        f"validation_dataset.root_dir={temp_dir}",
        f"seed={seed}",
        f"num_views={num_views}",
        f"save_mode={save_mode}",
    ]
    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers failure to start the process at all; report it
        # through the same error-string contract the caller already checks.
        return f"Error during inference: {str(e)}"

    # NOTE(review): this glob also matches input_image.png and
    # output_image_rmbg.png written by remove_background. Whether inference.py
    # writes its own outputs directly into root_dir is not visible here;
    # confirm the intended output location.
    return sorted(glob(os.path.join(temp_dir, "*.png")))
def process_image(input_url):
    """Gradio callback: remove the background, run inference, return a gallery.

    Args:
        input_url: the uploaded image as given by the ``gr.Image`` input
            (configured in this file with ``type="filepath"``).

    Returns:
        A list of ``(PIL.Image.Image, caption)`` tuples for ``gr.Gallery``,
        one per path returned by ``run_inference``. The caption is the
        file's base name.

    Raises:
        gr.Error: if background removal or inference reports a failure.
    """
    result = remove_background(input_url)
    if isinstance(result, str):
        # remove_background reports failure as an error string (and has
        # already deleted its temp dir) rather than raising.
        raise gr.Error(result)
    _removed_bg_path, temp_dir = result

    # finally: always delete the working directory, including when inference
    # fails or raises unexpectedly.
    try:
        output_images = run_inference(temp_dir)
        if isinstance(output_images, str):
            raise gr.Error(output_images)

        # Load every image into memory *before* temp_dir is deleted below.
        # Returning paths inside temp_dir would hand Gradio files that no
        # longer exist by the time it reads them.
        results = []
        for img_path in output_images:
            with Image.open(img_path) as img:
                results.append((img.copy(), os.path.basename(img_path)))
        return results
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
def gradio_interface():
    """Build the Gradio Blocks UI that feeds an uploaded image to ``process_image``.

    Layout:
    - a Markdown title;
    - a row holding the image input and a "Process" button;
    - an output gallery.

    Clicking the button calls ``process_image`` with the uploaded file path
    and shows its return value in the gallery.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as blocks_app:
        gr.Markdown("# Background Removal and Inference Pipeline")
        with gr.Row():
            uploaded = gr.Image(label="Image input", type="filepath")
            process_btn = gr.Button("Process")
        gallery = gr.Gallery(label="Output Images")
        # Wire the event while still inside the Blocks context.
        process_btn.click(process_image, inputs=[uploaded], outputs=[gallery])
    return blocks_app
# Build the UI at module level so `app` stays importable, but only start the
# server (which blocks the main thread by default) when this file is run as a
# script.
app = gradio_interface()

if __name__ == "__main__":
    app.launch()
|