Spaces:
Running
Running
File size: 10,730 Bytes
01664b3 5ceacf4 01664b3 17456cf 01664b3 20c01c5 01664b3 17456cf 5ceacf4 01664b3 6d737eb 01664b3 17456cf 01664b3 3c7feee 01664b3 6d737eb 01664b3 17456cf 5ceacf4 92d915f 17456cf 5ceacf4 17456cf 5ceacf4 17456cf 3c7feee 5ceacf4 92d915f 17456cf 01664b3 17456cf 01664b3 5ceacf4 01664b3 92d915f 5ceacf4 17456cf 92d915f 6d737eb 92d915f 6d737eb 17456cf 92d915f 17456cf 92d915f 17456cf 6d737eb 17456cf 6d737eb 92d915f 17456cf 6d737eb 17456cf 01664b3 6d737eb 92d915f 6d737eb 01664b3 92d915f 17456cf 92d915f 5ceacf4 01664b3 17456cf 01664b3 17456cf 01664b3 5ceacf4 01664b3 17456cf 01664b3 5ceacf4 01664b3 5ceacf4 01664b3 5ceacf4 01664b3 17456cf 01664b3 6d737eb 01664b3 17456cf 6d737eb 01664b3 17456cf 92d915f 6d737eb 92d915f 17456cf 20c01c5 01664b3 17456cf 01664b3 20c01c5 5ceacf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 |
import os
import re
import shutil
import time
from types import SimpleNamespace
from typing import Any, Callable, Generator, Optional
import gradio as gr
import numpy as np
from detectron2 import engine
from natsort import natsorted
from PIL import Image
from inference import main, setup_cfg
# ---------------------------------------------------------------------------
# internal settings
# ---------------------------------------------------------------------------
# number of processes handed to detectron2's engine.launch when running the model
NUM_PROCESSES = 1
# forwarded unchanged to inference.main
CROP = False
# forwarded to inference.main as its score threshold
SCORE_THRESHOLD = 0.8
# upper bound on the number of parts (and therefore image components) that are visualized
# TODO: we can replace this by having a slider and a single image visualization component rather than multiple components
MAX_PARTS = 5

# presumably mirrors the command-line namespace that setup_cfg expects — confirm against inference.setup_cfg;
# `output` is also the scratch directory that predict wipes and scans for rendered frames
ARGS = SimpleNamespace(
    config_file="configs/coco/instance-segmentation/swin/opd_v1_real.yaml",
    model=".data/models/motion_state_pred_opdformerp_rgb.pth",
    input_format="RGB",
    output=".output",
    cpu=True,
)

# initial value of the "Number of samples" input
NUM_SAMPLES = 10

# In-memory cache of rendered results, keyed by the temporary file path Gradio assigns to the uploaded RGB image.
# Each value holds one list of animation frames per detected part. Keeping results here lets users "replay" an
# animation without re-running the model, and keying by temp path is intended to let several users run the demo at
# once. Gradio is believed to derive that temp path from the file contents rather than the file name, so distinct
# images that share a name should not collide (NOTE(review): unverified).
# TODO: nothing ever evicts entries from this cache, so enough traffic within one process lifetime could exhaust RAM
# (the model itself also lives in CPU memory, since this demo is not designed to run on GPU). Possible fixes:
#   1. periodically evict old results, especially when the input image is reset;
#   2. cache results on disk instead of in memory, where the capacity is larger; or
#   3. keep results in the browser rather than in the backend (no working approach for this was found earlier).
outputs: dict[str, list[list[Image.Image]]] = dict()
def _find_images(path: str) -> dict[str, list[str]]:
    """
    Collect the generated PNG frames under ``path``, grouped by sub-folder.

    :param path: directory whose immediate sub-directories each hold the frames of one part
    :return: mapping of sub-folder name to the naturally-sorted PNG file paths inside it. Sub-folders are visited in
        natural-sort order so the part ordering is deterministic across runs.
    """
    images: dict[str, list[str]] = {}
    for folder in natsorted(os.listdir(path)):
        sub_path = os.path.join(path, folder)
        if os.path.isdir(sub_path):
            images[folder] = [
                os.path.join(sub_path, image_file)
                for image_file in natsorted(os.listdir(sub_path))
                if image_file.endswith(".png")
            ]
    return images


def _clear_output_dir(path: str) -> None:
    """Create ``path`` if it does not exist, then delete every file and sub-directory inside it."""
    os.makedirs(path, exist_ok=True)
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        # rmtree refuses symlinks, so only recurse into real directories and unlink everything else
        if os.path.isdir(full_path) and not os.path.islink(full_path):
            shutil.rmtree(full_path)
        else:
            os.remove(full_path)


def _load_image(path: str) -> "Image.Image":
    """Read an image fully into memory and release its file handle."""
    # Image.open is lazy and keeps the file open until the pixels are read. The cached frames outlive this call and
    # the output directory is wiped on the next run, so copy the pixel data into memory now.
    with Image.open(path) as image:
        return image.copy()


def predict(rgb_image: str, depth_image: str, intrinsic: np.ndarray, num_samples: int) -> list[Any]:
    """
    Run model on input image and generate output visualizations.

    :param rgb_image: local path to RGB image file, used for model prediction and visualization
    :param depth_image: local path to depth image file, used for visualization
    :param intrinsic: array of dimension (3, 3) representing the intrinsic matrix of the camera
    :param num_samples: number of visualization states to generate.
    :return: exactly MAX_PARTS component updates: one showing the first frame of each visualized part, followed by
        updates hiding the remaining image components.
    :raises gr.Error: if the RGB or depth image is missing (Gradio displays the message to the user).
    """
    # gr.Error only reaches the user when it is *raised*; validate before touching the output directory
    if not rgb_image:
        raise gr.Error("You must provide an RGB image before running the model.")
    if not depth_image:
        raise gr.Error("You must provide a depth image before running the model.")

    # clear old predictions
    # TODO: might be a better place for this than at the beginning of every invocation
    _clear_output_dir(ARGS.output)

    # run model
    cfg = setup_cfg(ARGS)
    engine.launch(
        main,
        NUM_PROCESSES,
        args=(
            cfg,
            rgb_image,
            depth_image,
            intrinsic,
            num_samples,
            CROP,
            SCORE_THRESHOLD,
        ),
    )

    # process output: keep only parts that actually produced frames, and at most MAX_PARTS of them
    # TODO: may want to select these in decreasing order of score
    part_frames = [frames for frames in _find_images(ARGS.output).values() if frames][:MAX_PARTS]
    outputs[rgb_image] = [[_load_image(frame) for frame in frames] for frames in part_frames]

    # one update per image component: show the first frame of each part, hide the components that are left over
    parts = outputs[rgb_image]
    return [
        *[gr.update(value=frames[0], visible=True) for frames in parts],
        *[gr.update(visible=False) for _ in range(MAX_PARTS - len(parts))],
    ]
def get_trigger(
    idx: int, fps: int = 25, oscillate: bool = True
) -> Callable[[str], Generator[Image.Image, None, None]]:
    """
    Return event listener trigger function for image component to animate image sequence.

    :param idx: index of part to animate from output
    :param fps: approximate rate at which images should be cycled through in frames per second. Note that the fps cannot
        be higher than the rate at which images can be returned and rendered. Defaults to 25
    :param oscillate: if True, animates part in reverse after running from start to end. Defaults to True
    :return: generator function taking the temp path of the RGB image and yielding the cached frames of part ``idx``
    :raises ValueError: if ``fps`` is not positive
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    def iter_images(rgb_image: str) -> Generator[Image.Image, None, None]:
        """Iterator to yield sequence of images for rendering, based on temp RGB image path"""
        if not rgb_image or rgb_image not in outputs:
            gr.Warning("You must upload an image and run the model before you can view the output.")
            return
        if idx >= len(outputs[rgb_image]):
            # must be raised, not just constructed, for Gradio to report it
            raise gr.Error("Could not find any images to load into this module.")

        frames = outputs[rgb_image][idx]
        # forward pass, optionally followed by the same frames in reverse
        sequence = [*frames, *reversed(frames)] if oscillate else list(frames)

        start_time = time.monotonic()
        for frame_count, im in enumerate(sequence):
            # schedule each frame relative to the start so one slow frame does not push back all later ones
            delay = frame_count / fps - (time.monotonic() - start_time)
            # frame 0 is always "late" by the few microseconds since start_time; only warn for real lag
            if frame_count > 0 and delay < 0:
                print("[WARNING] frames cannot be rendered at the specified FPS due to processing/rendering time.")
            time.sleep(max(delay, 0))
            yield im

    return iter_images
def clear_outputs(*_inputs: Any) -> list[Any]:
    """
    Remove images from image components.

    Positional arguments are accepted and ignored: Gradio passes the value of every component listed in an event's
    ``inputs`` to the handler, and a handler with no parameters would raise TypeError in that case.

    :return: one update per image component, clearing its value and leaving only the first component visible.
    """
    return [gr.update(value=None, visible=(idx == 0)) for idx in range(MAX_PARTS)]
with gr.Blocks() as demo:
    gr.Markdown(
        """
# OPDMulti Demo
We tackle the openable-part-detection (OPD) problem where we identify in a single-view image parts that are openable and their motion parameters. Our OPDFORMER architecture outputs segmentations for openable parts on potentially multiple objects, along with each part’s motion parameters: motion type (translation or rotation, indicated by blue or purple mask), motion axis and origin (see green arrows and points). For each openable part, we predict the motion parameters (axis and origin) in object coordinates along with an object pose prediction to convert to camera coordinates.
More information about the project, including code, can be found [here](https://3dlg-hcvc.github.io/OPDMulti/).
Upload an image to see a visualization of its range of motion below. Only the RGB image is needed for the model itself, but the depth image is required as of now for the visualization of motion.
If you know the intrinsic matrix of your camera, you can specify that here or otherwise use the default matrix which will work with any of the provided examples.
You can also change the number of samples to define the number of states in the visualization generated.
"""
    )

    # inputs
    with gr.Row():
        rgb_image = gr.Image(
            image_mode="RGB", source="upload", type="filepath", label="RGB Image", show_label=True, interactive=True
        )
        depth_image = gr.Image(
            image_mode="I;16", source="upload", type="filepath", label="Depth Image", show_label=True, interactive=True
        )

    intrinsic = gr.Dataframe(
        value=[
            [
                214.85935872395834,
                0.0,
                125.90160319010417,
            ],
            [
                0.0,
                214.85935872395834,
                95.13726399739583,
            ],
            [
                0.0,
                0.0,
                1.0,
            ],
        ],
        row_count=(3, "fixed"),
        col_count=(3, "fixed"),
        datatype="number",
        type="numpy",
        label="Intrinsic matrix",
        show_label=True,
        interactive=True,
    )
    num_samples = gr.Number(
        value=NUM_SAMPLES,
        label="Number of samples",
        show_label=True,
        interactive=True,
        precision=0,
        minimum=3,
        maximum=20,
    )

    # specify examples which can be used to start
    examples = gr.Examples(
        examples=[
            ["examples/59-4860.png", "examples/59-4860_d.png"],
            ["examples/174-8460.png", "examples/174-8460_d.png"],
            ["examples/187-0.png", "examples/187-0_d.png"],
            ["examples/187-23040.png", "examples/187-23040_d.png"],
        ],
        inputs=[rgb_image, depth_image],
        api_name=False,
        examples_per_page=2,
    )

    submit_btn = gr.Button("Run model")

    # output
    explanation = gr.Markdown(
        value=f"# Output\nClick on an image to see an animation of the part motion. As of now, only up to {MAX_PARTS} parts can be visualized due to limitations of the visualizer."
    )

    images = [
        gr.Image(type="pil", label=f"Part {idx + 1}", show_download_button=False, visible=(idx == 0))
        for idx in range(MAX_PARTS)
    ]
    # clicking a part image animates that part's cached frames (looked up by the RGB image's temp path)
    for idx, image_comp in enumerate(images):
        image_comp.select(get_trigger(idx), inputs=rgb_image, outputs=image_comp, api_name=False)

    # if user changes input, clear output images; clear_outputs needs no input values, so none are passed
    rgb_image.change(clear_outputs, inputs=None, outputs=images, api_name=False)
    depth_image.change(clear_outputs, inputs=None, outputs=images, api_name=False)

    submit_btn.click(
        fn=predict, inputs=[rgb_image, depth_image, intrinsic, num_samples], outputs=images, api_name=False
    )
# enable the request queue (required for the generator-based animation handlers) without exposing the queue API,
# then start the server; Blocks.queue returns the Blocks instance, so the calls chain
demo.queue(api_open=False).launch()
|