Spaces:
Running
Running
import os
import re
import shutil
import time
from collections.abc import Iterator
from types import SimpleNamespace
from typing import Any

import gradio as gr
import numpy as np
from detectron2 import engine
from PIL import Image

from inference import main, setup_cfg
# ---------------------------------------------------------------------------
# Internal settings
# ---------------------------------------------------------------------------
NUM_PROCESSES: int = 1  # process count handed to detectron2's engine.launch
CROP: bool = True  # crop flag forwarded to inference.main
SCORE_THRESHOLD: float = 0.8  # score threshold forwarded to inference.main
MAX_PARTS: int = 5  # most detected parts predict() considers for display

# Namespace read by inference.setup_cfg; presumably mirrors the fields of the
# inference script's command-line arguments (verify against inference.py).
ARGS = SimpleNamespace(
    config_file="configs/coco/instance-segmentation/swin/opd_v1_real.yaml",
    model="../data/models/motion_state_pred_opdformerp_rgb.pth",
    input_format="RGB",
    output=".output",  # scratch folder; predict() empties it before every run
    cpu=True,  # presumably forces CPU inference
)

# NOTE(review): nothing in this file ever appends to this list; looks like a leftover.
outputs: list = []
def _clear_directory(path: str) -> None:
    """Remove every file and sub-directory inside ``path``.

    The directory itself is kept, and created if it does not exist yet, so the
    very first request on a fresh checkout does not fail.
    """
    os.makedirs(path, exist_ok=True)
    # Snapshot the listing first so we are not deleting while iterating the directory.
    with os.scandir(path) as it:
        entries = list(it)
    for entry in entries:
        # Don't follow symlinks: rmtree refuses them, and only the link should go.
        if entry.is_dir(follow_symlinks=False):
            shutil.rmtree(entry.path)
        else:
            os.remove(entry.path)


def _find_part_frames(path: str) -> dict[str, list[str]]:
    """Map each sub-directory name of ``path`` to the sorted ``.png`` paths inside it.

    Sub-directories are visited in sorted order so which part gets shown is
    deterministic (``os.listdir`` order is arbitrary).
    """
    frames_by_part: dict[str, list[str]] = {}
    for name in sorted(os.listdir(path)):
        part_dir = os.path.join(path, name)
        if os.path.isdir(part_dir):
            frames_by_part[name] = [
                os.path.join(part_dir, file_name)
                for file_name in sorted(os.listdir(part_dir))
                if file_name.endswith(".png")
            ]
    return frames_by_part


def _load_frames(paths: list[str]) -> list[Any]:
    """Read every image in ``paths`` fully into memory and close its file handle."""
    frames = []
    for frame_path in paths:
        with Image.open(frame_path) as img:
            # copy() forces the pixel data to load before the file is closed, so the
            # frames stay valid even after a later request wipes the output folder.
            frames.append(img.copy())
    return frames


def _loop_frames(frames: list[Any], frame_delay: float = 0.025, loop_pause: float = 3.0) -> Iterator[Any]:
    """Yield ``frames`` over and over as a slow animation.

    Sleeps ``frame_delay`` seconds before each frame and ``loop_pause`` seconds
    between loops. Returns immediately when ``frames`` is empty; otherwise the
    ``while True`` loop would spin forever without yielding anything.
    """
    if not frames:
        return
    while True:
        for frame in frames:
            time.sleep(frame_delay)
            yield frame
        time.sleep(loop_pause)


def predict(rgb_image: str, depth_image: str, intrinsics: np.ndarray, num_samples: int) -> Iterator[Any]:
    """Run ``inference.main`` on one input and stream the frames rendered for a detected part.

    Args:
        rgb_image: File path of the uploaded RGB image.
        depth_image: File path of the uploaded depth image.
        intrinsics: Intrinsics matrix (NumPy array), passed through to ``inference.main``.
        num_samples: Number of samples, passed through to ``inference.main``.

    Yields:
        The PNG frames written for the first detected part, as in-memory PIL images,
        looping indefinitely.

    Raises:
        gr.Error: If the run produced no part directory containing any frames.
    """
    # NOTE(review): every request shares ARGS.output, so two jobs running at the same
    # time would clobber each other's frames; use per-request temp dirs if the queue
    # is ever allowed to run more than one job concurrently.
    _clear_directory(ARGS.output)

    cfg = setup_cfg(ARGS)
    engine.launch(
        main,
        NUM_PROCESSES,
        args=(cfg, rgb_image, depth_image, intrinsics, num_samples, CROP, SCORE_THRESHOLD),
    )

    # TODO: may want to select these in decreasing order of score
    frames_by_part = _find_part_frames(ARGS.output)
    # Skip parts without frames; an empty frame list has nothing to animate.
    parts = [name for name, paths in frames_by_part.items() if paths][:MAX_PARTS]
    if not parts:
        raise gr.Error("No parts were detected in the input image.")

    # Only a single output component exists for now, so animate the first part.
    yield from _loop_frames(_load_frames(frames_by_part[parts[0]]))
# Default 3x3 matrix pre-filled in the "Intrinsics matrix" table.
_DEFAULT_INTRINSICS = [
    [214.85935872395834, 0.0, 125.90160319010417],
    [0.0, 214.85935872395834, 95.13726399739583],
    [0.0, 0.0, 1.0],
]

# Options shared by both upload widgets; predict() receives their file paths.
_UPLOAD_OPTIONS = {"source": "upload", "type": "filepath", "show_label": True, "interactive": True}

with gr.Blocks() as demo:
    gr.Markdown(
        "\n"
        "# OPDMulti Demo\n"
        "Upload an image to see its range of motion.\n"
    )
    # TODO: add gr.Examples

    with gr.Row():
        rgb_image = gr.Image(image_mode="RGB", label="RGB Image", **_UPLOAD_OPTIONS)
        depth_image = gr.Image(image_mode="I;16", label="Depth Image", **_UPLOAD_OPTIONS)

    intrinsics = gr.Dataframe(
        value=_DEFAULT_INTRINSICS,
        row_count=(3, "fixed"),
        col_count=(3, "fixed"),
        datatype="number",
        type="numpy",
        label="Intrinsics matrix",
        show_label=True,
        interactive=True,
    )
    num_samples = gr.Number(
        value=10,
        precision=0,  # whole numbers only
        minimum=3,
        maximum=20,
        label="Number of samples",
        show_label=True,
        interactive=True,
    )
    submit_btn = gr.Button("Run model")

    # TODO: do we want to set a maximum limit on how many parts we render (e.g. one hidden
    # gr.Image per part, up to MAX_PARTS)? We could also show the number of components identified.
    # Single output slot for now; predict() yields animation frames into it.
    image = gr.Image(type="pil", visible=True)

    model_inputs = [rgb_image, depth_image, intrinsics, num_samples]
    submit_btn.click(fn=predict, inputs=model_inputs, outputs=image, api_name="run_model")

# predict() is a generator, which Gradio serves through the queue; api_open=False keeps
# clients from bypassing the queue through the raw REST routes.
demo.queue(api_open=False)
demo.launch()