import os

# Install the local package in development mode before importing from it.
os.system('python setup.py develop')

import argparse
import json
from pathlib import Path

import gradio as gr
import matplotlib

from gradio_utils.utils import (process_img, get_select_coords, select_skeleton,
                                reset_skeleton, reset_kp, process, update_examples)

LENGTH = 480  # Length of the square area displaying/editing images

matplotlib.use('agg')
model_dir = Path('./checkpoints')
parser = argparse.ArgumentParser(description='EdgeCape Demo')
parser.add_argument('--checkpoint',
                    help='checkpoint path',
                    default='ckpt/1shot_split1.pth')
args = parser.parse_args()
checkpoint_path = args.checkpoint
device = 'cuda'
TIMEOUT = 80

with gr.Blocks() as demo:
    gr.Markdown('''
    # EdgeCape
    We introduce EdgeCape, a novel framework that overcomes the limitations of static, equal-weight pose graphs by predicting the graph's edge weights, which optimizes localization. To further leverage structural priors, we propose integrating Markovian Structural Bias, which modulates the self-attention interaction between nodes based on the number of hops between them. We show that this improves the model's ability to capture global spatial dependencies. Evaluated on the MP-100 benchmark, which includes 100 categories and over 20K images, EdgeCape achieves state-of-the-art results in the 1-shot setting and leads among similar-sized methods in the 5-shot setting, significantly improving keypoint localization accuracy.
    ### [Paper](https://arxiv.org/pdf/2411.16665) | [Project Page](https://orhir.github.io/edge_cape/)
    ## Instructions
    1. Upload a support image from the same category as the object you want to pose.
    2. Mark keypoints on the middle image. When finished, press 'Confirm Clicked Points'.
    3. Mark limbs on the right image.
    4. Upload the image of the object you want to pose as the query image (**bottom**).
    5. Click **Evaluate** to pose the query image.
    ''')
    global_state = gr.State({
        "images": {},
        "points": [],
        "skeleton": [],
        "prev_point": None,
        "curr_type_point": "start",
        "load_example": False,
    })
    with gr.Row():
        # Upload & Preprocess Image Column
        with gr.Column():
            gr.Markdown(
                """
                Upload & Preprocess Image
                """
            )
            support_image = gr.Image(
                height=LENGTH,
                width=LENGTH,
                type="pil",
                image_mode="RGB",
                label="Support Image",
                show_label=True,
                interactive=True,
            )
        # Click Points Column
        with gr.Column():
            gr.Markdown(
                """
                Click Points
                """
            )
            kp_support_image = gr.Image(
                type="pil",
                label="Keypoints Image",
                show_label=True,
                height=LENGTH,
                width=LENGTH,
                interactive=False,
                show_fullscreen_button=False,
            )
            with gr.Row():
                confirm_kp_button = gr.Button("Confirm Clicked Points", scale=3)
            with gr.Row():
                undo_kp_button = gr.Button("Undo Clicked Points", scale=3)
        # Editing Results Column
        with gr.Column():
            gr.Markdown(
                """
                Click Skeleton
                """
            )
            skel_support_image = gr.Image(
                type="pil",
                label="Skeleton Image",
                show_label=True,
                height=LENGTH,
                width=LENGTH,
                interactive=False,
                show_fullscreen_button=False,
            )
            with gr.Row():
                pass
            with gr.Row():
                undo_skel_button = gr.Button("Undo Skeleton")
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                Query Image
                """
            )
            query_image = gr.Image(
                type="pil",
                image_mode="RGB",
                label="Query Image",
                show_label=True,
                interactive=True,
            )
        with gr.Column():
            gr.Markdown(
                """
                Output
                """
            )
            output_img = gr.Plot(label="Output Image")
    with gr.Row():
        eval_btn = gr.Button(value="Evaluate")
    with gr.Row():
        gr.Markdown("## Examples")
    with gr.Row():
        example_null = gr.Textbox(type='text', visible=False)
    with gr.Row():
        examples = gr.Examples(
            [
                ['examples/dog2.png', 'examples/dog1.png',
                 json.dumps({
                     'points': [(232, 200), (312, 204), (228, 264), (316, 472),
                                (316, 616), (296, 868), (412, 872), (416, 624),
                                (604, 608), (648, 860), (764, 852), (696, 608),
                                (684, 432)],
                     'skeleton': [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 7),
                                  (7, 6), (3, 12), (12, 8), (8, 9), (12, 11),
                                  (11, 10)],
                 })],
                ['examples/sofa1.jpg', 'examples/sofa2.png',
                 json.dumps({
                     'points': [[272, 561], [193, 482], [339, 460], [445, 530],
                                [264, 369], [203, 318], [354, 300], [457, 341],
                                [345, 63], [187, 68]],
                     'skeleton': [[0, 4], [1, 5], [2, 6], [3, 7], [7, 6], [6, 5],
                                  [5, 4], [4, 7], [5, 9], [9, 8], [8, 6]],
                 })],
                ['examples/person1.jpeg', 'examples/person2.jpeg',
                 json.dumps({
                     'points': [[322, 488], [431, 486], [526, 644], [593, 486],
                                [697, 492], [407, 728], [522, 726], [625, 737],
                                [515, 798]],
                     'skeleton': [[0, 1], [1, 3], [3, 4], [1, 2], [2, 3], [5, 6],
                                  [6, 7], [7, 8], [8, 5]],
                 })],
            ],
            inputs=[support_image, query_image, example_null],
            outputs=[support_image, kp_support_image, skel_support_image,
                     query_image, global_state],
            fn=update_examples,
            run_on_click=True,
            examples_per_page=5,
            cache_examples=False,
        )

    support_image.upload(process_img,
                         inputs=[support_image, global_state],
                         outputs=[kp_support_image, global_state])
    kp_support_image.select(get_select_coords,
                            [global_state],
                            [global_state, kp_support_image],
                            queue=False)
    confirm_kp_button.click(reset_skeleton,
                            inputs=global_state,
                            outputs=skel_support_image)
    undo_kp_button.click(reset_kp,
                         inputs=global_state,
                         outputs=[kp_support_image, skel_support_image])
    undo_skel_button.click(reset_skeleton,
                           inputs=global_state,
                           outputs=skel_support_image)
    skel_support_image.select(select_skeleton,
                              inputs=[global_state],
                              outputs=[global_state, skel_support_image])
    eval_btn.click(fn=process,
                   inputs=[query_image, global_state],
                   outputs=[output_img])

if __name__ == "__main__":
    print("Start app", parser.parse_args())
    gr.close_all()
    demo.launch(show_api=False)