# EdgeCape / app.py
import os
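# Install the repository's package in editable mode at startup (a common pattern for hosted Gradio demos).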
os.system('python setup.py develop')
import argparse
import json
from pathlib import Path
import gradio as gr
import matplotlib
from gradio_utils.utils import (process_img, get_select_coords, select_skeleton,
reset_skeleton, reset_kp, process, update_examples)
LENGTH = 480 # Length of the square area displaying/editing images
matplotlib.use('agg')
model_dir = Path('./checkpoints')
parser = argparse.ArgumentParser(description='EdgeCape Demo')
parser.add_argument('--checkpoint',
help='checkpoint path',
default='ckpt/1shot_split1.pth')
args = parser.parse_args()
checkpoint_path = args.checkpoint
device = 'cuda'
TIMEOUT = 80
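
# Build the Gradio UI: a support image to annotate, keypoint and skeleton editors, a query image, and an output plot.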
with gr.Blocks() as demo:
gr.Markdown('''
# We introduce EdgeCape, a novel framework that overcomes the limitations of static, equal-weight pose graphs by predicting the graph's edge weights, which improves keypoint localization.
To further leverage structural priors, we propose integrating Markovian Structural Bias, which modulates the self-attention interaction between nodes based on the number of hops between them.
We show that this improves the model’s ability to capture global spatial dependencies.
Evaluated on the MP-100 benchmark, which includes 100 categories and over 20K images,
EdgeCape achieves state-of-the-art results in the 1-shot setting and leads among similar-sized methods in the 5-shot setting, significantly improving keypoint localization accuracy.
### [Paper](https://arxiv.org/pdf/2411.16665) | [Project Page](https://orhir.github.io/edge_cape/)
## Instructions
1. Upload an image from the same category as the object you want to pose.
2. Mark keypoints on the middle image. When finished, press 'Confirm Clicked Points'.
3. Mark limbs on the right image.
4. Upload an image of the object you want to pose as the query image (**bottom**).
5. Click **Evaluate** to pose the query image.
''')
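    # The Markovian Structural Bias described above, sketched roughly (not the paper's exact formulation):
    # each self-attention logit between keypoints i and j gets a learned offset indexed by their hop
    # distance in the support skeleton, e.g. logits[i, j] += hop_bias[min(hops(i, j), max_hops)].

    # Per-session state shared by the callbacks: processed images, clicked keypoints,
    # skeleton edges, and bookkeeping for the point currently being placed.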
global_state = gr.State({
"images": {},
"points": [],
"skeleton": [],
"prev_point": None,
"curr_type_point": "start",
"load_example": False,
})
with gr.Row():
# Upload & Preprocess Image Column
with gr.Column():
gr.Markdown(
"""<p style="text-align: center; font-size: 20px">Upload & Preprocess Image</p>"""
)
support_image = gr.Image(
height=LENGTH,
width=LENGTH,
type="pil",
image_mode="RGB",
label="Support Image",
show_label=True,
interactive=True,
)
# Click Points Column
with gr.Column():
gr.Markdown(
"""<p style="text-align: center; font-size: 20px">Click Points</p>"""
)
kp_support_image = gr.Image(
type="pil",
label="Keypoints Image",
show_label=True,
height=LENGTH,
width=LENGTH,
interactive=False,
show_fullscreen_button=False,
)
with gr.Row():
confirm_kp_button = gr.Button("Confirm Clicked Points", scale=3)
with gr.Row():
undo_kp_button = gr.Button("Undo Clicked Points", scale=3)
# Editing Results Column
with gr.Column():
gr.Markdown(
"""<p style="text-align: center; font-size: 20px">Click Skeleton</p>"""
)
skel_support_image = gr.Image(
type="pil",
label="Skeleton Image",
show_label=True,
height=LENGTH,
width=LENGTH,
interactive=False,
show_fullscreen_button=False,
)
            with gr.Row():
                # Empty row kept as a layout spacer (mirrors the 'Confirm Clicked Points' row in the adjacent column).
                pass
with gr.Row():
undo_skel_button = gr.Button("Undo Skeleton")
with gr.Row():
with gr.Column():
gr.Markdown(
"""<p style="text-align: center; font-size: 20px">Query Image</p>"""
)
query_image = gr.Image(
type="pil",
image_mode="RGB",
label="Query Image",
show_label=True,
interactive=True,
)
with gr.Column():
gr.Markdown(
"""<p style="text-align: center; font-size: 20px">Output</p>"""
)
            output_img = gr.Plot(label="Output Image")
with gr.Row():
eval_btn = gr.Button(value="Evaluate")
with gr.Row():
gr.Markdown("## Examples")
with gr.Row():
example_null = gr.Textbox(type='text',
visible=False
)
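    # Bundled examples: each entry is a support image, a query image, and a JSON string of
    # pre-annotated keypoint coordinates plus skeleton edges given as pairs of keypoint indices.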
with gr.Row():
examples = gr.Examples([
['examples/dog2.png',
'examples/dog1.png',
json.dumps({
'points': [(232, 200), (312, 204), (228, 264), (316, 472), (316, 616), (296, 868), (412, 872),
(416, 624), (604, 608), (648, 860), (764, 852), (696, 608), (684, 432)],
'skeleton': [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5),
(3, 7), (7, 6), (3, 12), (12, 8), (8, 9),
(12, 11), (11, 10)],
})
],
['examples/sofa1.jpg',
'examples/sofa2.png',
json.dumps({'points': [[272, 561], [193, 482], [339, 460], [445, 530], [264, 369], [203, 318], [354, 300],
[457, 341], [345, 63], [187, 68]],
'skeleton': [[0, 4], [1, 5], [2, 6], [3, 7], [7, 6], [6, 5],
[5, 4], [4, 7], [5, 9], [9, 8], [8, 6]],
})],
['examples/person1.jpeg',
'examples/person2.jpeg',
json.dumps({
'points': [[322, 488], [431, 486], [526, 644], [593, 486], [697, 492], [407, 728],
[522, 726], [625, 737], [515, 798]],
'skeleton': [[0, 1], [1, 3], [3, 4], [1, 2], [2, 3], [5, 6], [6, 7], [7, 8], [8, 5]],
})]
],
inputs=[support_image, query_image, example_null],
outputs=[support_image, kp_support_image, skel_support_image, query_image, global_state],
fn=update_examples,
run_on_click=True,
examples_per_page=5,
cache_examples=False,
)
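
    # Wire UI events to the callbacks implemented in gradio_utils.utils:
    # upload/click handlers update the shared state, and 'Evaluate' runs the model on the query image.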
support_image.upload(process_img,
inputs=[support_image, global_state],
outputs=[kp_support_image, global_state])
kp_support_image.select(get_select_coords,
[global_state],
[global_state, kp_support_image],
queue=False, )
confirm_kp_button.click(reset_skeleton,
inputs=global_state,
outputs=skel_support_image)
undo_kp_button.click(reset_kp,
inputs=global_state,
outputs=[kp_support_image, skel_support_image])
undo_skel_button.click(reset_skeleton,
inputs=global_state,
outputs=skel_support_image)
skel_support_image.select(select_skeleton,
inputs=[global_state],
outputs=[global_state, skel_support_image])
eval_btn.click(fn=process,
inputs=[query_image, global_state],
outputs=[output_img])
if __name__ == "__main__":
print("Start app", parser.parse_args())
gr.close_all()
demo.launch(show_api=False)