"""Interface for labeling concepts in images. """ from typing import Optional import gradio as gr from src import global_variables from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME def get_image( step: int, split: str, index: str, filtered_indices: dict, profile: gr.OAuthProfile ): username = profile.username try: int_index = int(index) except: gr.Warning("Error parsing index using 0") int_index = 0 sample_idx = int_index + step if sample_idx < 0: gr.Warning("No previous image.") sample_idx = 0 if sample_idx >= len(global_variables.all_metadata[split]): gr.Warning("No next image.") sample_idx = len(global_variables.all_metadata[split]) - 1 sample = global_variables.all_metadata[split][sample_idx] image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}" try: username_votes = global_variables.all_votes[sample["id"]][username] voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)] unseen_concepts = [c for c in CONCEPTS if c not in username_votes] except KeyError: voted_concepts = [] unseen_concepts = [] tie_concepts = [c for c in CONCEPTS if sample[c] is None] return ( image_path, voted_concepts, f"{split}:{sample_idx}", sample["class"], {c: sample[c] for c in CONCEPTS}, str(sample_idx), unseen_concepts, tie_concepts, filtered_indices, ) def make_get_image(step): def f( split: str, index: str, filtered_indices: dict, profile: gr.OAuthProfile ): return get_image(step, split, index, filtered_indices, profile) return f get_next_image = make_get_image(1) get_prev_image = make_get_image(-1) get_current_image = make_get_image(0) def submit_label( voted_concepts: list, current_image: Optional[str], split, index, filtered_indices, profile: gr.OAuthProfile ): username = profile.username if current_image is None: gr.Warning("No image selected.") return None, None, None, None, None, None, None, index, filtered_indices global_variables.update_votes(username, current_image, voted_concepts) gr.Info("Submit success") return get_next_image( split, index, filtered_indices, profile ) def save_current_work( profile: gr.OAuthProfile, ): username = profile.username global_variables.save_current_work(username) gr.Info("Save success") with gr.Blocks() as interface: with gr.Row(): with gr.Column(): with gr.Group(): gr.Markdown( "## # Image Selection", ) split = gr.Radio( label="Split", choices=["train", "test"], value="train", ) index = gr.Textbox( value="0", label="Index", max_lines=1, ) with gr.Row(): prev_button = gr.Button( value="Prev", ) next_button = gr.Button( value="Next", ) gr.LoginButton() submit_button = gr.Button( value="Local Submit", ) with gr.Row(): save_button = gr.Button( value="Save", ) with gr.Group(): voted_concepts = gr.CheckboxGroup( label="Voted Concepts", choices=CONCEPTS, ) unseen_concepts = gr.CheckboxGroup( label="Previously Unseen Concepts", choices=CONCEPTS, ) tie_concepts = gr.CheckboxGroup( label="Tie Concepts", choices=CONCEPTS, ) with gr.Group(): gr.Markdown( "## # Image Info", ) im_class = gr.Textbox( label="Class", ) im_concepts = gr.JSON( label="Concepts", ) with gr.Column(): image = gr.Image( label="Image", ) current_image = gr.State(None) filtered_indices = gr.State({ split: list(range(len(global_variables.all_metadata[split]))) for split in global_variables.all_metadata }) common_output = [ image, voted_concepts, current_image, im_class, im_concepts, index, unseen_concepts, tie_concepts, filtered_indices, ] common_input = [split, index, filtered_indices] prev_button.click( get_prev_image, inputs=common_input, outputs=common_output ) next_button.click( get_next_image, inputs=common_input, outputs=common_output ) submit_button.click( submit_label, inputs=[voted_concepts, current_image, split, index, filtered_indices], outputs=common_output ) index.submit( get_current_image, inputs=common_input, outputs=common_output, ) save_button.click( save_current_work, outputs=[image] )