|
import torch |
|
import os |
|
import tarfile |
|
import requests |
|
import shutil |
|
import tempfile |
|
import gradio as gr |
|
from PIL import Image |
|
from rembg import remove |
|
import sys |
|
import subprocess |
|
from glob import glob |
|
import requests |
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
# Sharing link for the SMPL-related asset archive.
# NOTE(review): SharePoint sharing links often serve an HTML landing page
# rather than the raw file unless a direct-download form is used — confirm
# that this URL really yields the tarball.
onedrive_url = "https://hkustconnect-my.sharepoint.com/:u:/g/personal/plibp_connect_ust_hk/EZQphP-2y5BGhEIe8jb03i4BIcqiJ2mUW2JmGC5s0VKOdw?e=qVzBBD"

# Local file name the archive is saved under before extraction.
destination_tar = "smpl_related.tar.gz"

# Folder whose presence means the assets are already available locally.
destination_folder = "smpl_related"
|
|
|
|
|
def download_file(url, destination):
    """Download ``url`` and store the response body at ``destination``.

    The body is written to disk chunk by chunk, so the whole file is never
    held in memory at once.

    Args:
        url: HTTP(S) address to fetch.
        destination: Path of the local file to create or overwrite.

    Raises:
        Exception: If the server responds with a status code other than 200.
        requests.RequestException: On connection failures or timeouts.
    """
    print(f"Downloading {url} to {destination}...")

    # stream=True defers fetching the body; the context manager makes sure
    # the connection is released even when we bail out early.
    # timeout=(connect, read): the read timeout applies per socket read,
    # not to the whole transfer, so long downloads are still allowed.
    with requests.get(url, stream=True, timeout=(10, 300)) as response:
        if response.status_code != 200:
            raise Exception(f"Failed to download file. Status code: {response.status_code}")

        with open(destination, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                # Keep-alive chunks can be empty; skip them.
                if chunk:
                    f.write(chunk)

    print(f"Downloaded file to {destination}")
|
|
|
|
|
def extract_tar(file_path, extract_to):
    """Extract the gzip-compressed tar archive ``file_path`` into ``extract_to``.

    Members that would be written outside ``extract_to`` are rejected, since
    the archive comes from a remote download and must not be trusted blindly.

    Args:
        file_path: Path to a ``.tar.gz`` archive.
        extract_to: Directory that receives the archive contents.

    Raises:
        tarfile.TarError: If the archive is unreadable or contains a member
            that would escape ``extract_to``.
    """
    print(f"Extracting {file_path} to {extract_to}...")

    with tarfile.open(file_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Interpreters with extraction filters: "data" refuses absolute
            # paths, parent-directory escapes, and special files.
            tar.extractall(path=extract_to, filter="data")
        else:
            # Older interpreters: check every member name by hand before
            # extracting anything. This covers path escapes by name only.
            root = os.path.realpath(extract_to)
            for member in tar.getmembers():
                target = os.path.realpath(os.path.join(root, member.name))
                if os.path.commonpath([root, target]) != root:
                    raise tarfile.TarError(
                        f"Unsafe path in archive: {member.name}"
                    )
            tar.extractall(path=extract_to)

    print("Extraction completed.")
|
|
|
|
|
# One-time asset bootstrap: fetch and unpack the archive only when the
# target folder is missing. Failures are reported but do not stop start-up.
if os.path.exists(destination_folder):
    print(f"Folder {destination_folder} already exists. Skipping download and extraction.")
else:
    try:
        # Fetch the archive, unpack it into the working directory, then
        # delete the tarball since it is no longer needed.
        download_file(onedrive_url, destination_tar)
        extract_tar(destination_tar, "./")
        os.remove(destination_tar)
        print(f"Cleaned up the tar file: {destination_tar}")
    except Exception as err:
        print(f"An error occurred: {err}")
|
|
|
|
|
# Ensure the checkpoint directory exists before it is populated.
os.makedirs("ckpts", exist_ok=True)

# Download the model repository snapshot into ./ckpts.
snapshot_download(repo_id="pengHTYX/PSHuman_Unclip_768_6views", local_dir="./ckpts")
|
|
|
|
|
def remove_background(input_url):
    """Copy the input image, converted to RGBA, into a fresh temp directory.

    Args:
        input_url: Anything ``PIL.Image.open`` accepts (the UI passes a
            file path).

    Returns:
        ``(image_path, temp_dir)`` on success, where ``image_path`` is the
        saved ``input_image.png`` inside the newly created ``temp_dir``.
        On failure, a string starting with ``"Error"``; the temp directory
        is removed in that case.
    """
    temp_dir = tempfile.mkdtemp()
    image_path = os.path.join(temp_dir, 'input_image.png')

    try:
        # The context manager closes the source file handle deterministically.
        with Image.open(input_url) as source:
            source.convert("RGBA").save(image_path)
    except Exception as e:
        shutil.rmtree(temp_dir)
        return f"Error downloading or saving the image: {str(e)}"

    # NOTE: the rembg-based background removal step is currently disabled;
    # the RGBA-converted input is handed on unchanged. To re-enable it, run
    # `remove(...)` on the saved image and return the resulting file path.
    return image_path, temp_dir
|
|
|
def run_inference(temp_dir):
    """Run ``inference.py`` on ``temp_dir`` and collect the PNG files in it.

    Args:
        temp_dir: Directory passed to the script as
            ``validation_dataset.root_dir``; it is also the directory that
            is scanned for ``*.png`` files afterwards.

    Returns:
        A sorted list of PNG paths found directly in ``temp_dir`` on
        success, or a string starting with ``"Error during inference"`` if
        the subprocess could not be started or exited with a non-zero code.
    """
    inference_config = "configs/inference-768-6view.yaml"
    pretrained_model = "./ckpts"
    crop_size = 740
    seed = 600
    num_views = 7
    save_mode = "rgb"

    command = [
        # Use the interpreter running this app, so the child process sees
        # the same environment; fall back to PATH lookup if it is unknown.
        sys.executable or "python", "inference.py",
        "--config", inference_config,
        f"pretrained_model_name_or_path={pretrained_model}",
        f"validation_dataset.crop_size={crop_size}",
        "with_smpl=false",
        f"validation_dataset.root_dir={temp_dir}",
        f"seed={seed}",
        f"num_views={num_views}",
        f"save_mode={save_mode}",
    ]

    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a failure to launch the interpreter at all.
        return f"Error during inference: {str(e)}"

    # glob() gives no ordering guarantee; sort for a stable result order.
    return sorted(glob(os.path.join(temp_dir, "*.png")))
|
|
|
def process_image(input_url):
    """Prepare the input image, run inference, and return gallery entries.

    Args:
        input_url: Image reference forwarded to ``remove_background``.

    Returns:
        A list of ``(path, path)`` tuples, one per path returned by
        ``run_inference``.

    Raises:
        gr.Error: If either step reports failure through an ``"Error..."``
            string.
    """
    prepared = remove_background(input_url)

    # Both helpers signal failure by returning a string that starts with
    # "Error" instead of their normal result.
    if isinstance(prepared, str) and prepared.startswith("Error"):
        raise gr.Error(prepared)

    _image_path, work_dir = prepared

    inference_result = run_inference(work_dir)
    if isinstance(inference_result, str) and inference_result.startswith("Error"):
        # Nothing will be displayed, so the working directory can go.
        shutil.rmtree(work_dir)
        raise gr.Error(inference_result)

    # Each entry pairs a path with itself, matching the original output shape.
    return [(png_path, png_path) for png_path in inference_result]
|
|
|
def gradio_interface():
    """Build the Gradio Blocks UI and return it without launching it.

    The layout is a heading, a row holding the image input and the
    "Process" button, and a gallery for the results. Clicking the button
    calls ``process_image`` with the image and sends its return value to
    the gallery.
    """
    with gr.Blocks() as blocks_app:
        gr.Markdown("# Background Removal and Inference Pipeline")

        with gr.Row():
            # type="filepath" makes the component hand over a path string.
            image_in = gr.Image(label="Image input", type="filepath")
            process_btn = gr.Button("Process")

        gallery_out = gr.Gallery(label="Output Images")

        process_btn.click(
            process_image,
            inputs=[image_in],
            outputs=[gallery_out],
        )

    return blocks_app
|
|
|
|
|
# Build the UI once at import time and start serving it.
app = gradio_interface()
app.launch()
|
|