"""Interface for labeling concepts in images."""

from typing import Optional
import random

import gradio as gr

from src import global_variables
from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME


def _matches_filter(sample_concepts: dict, required: list) -> bool:
    """Return True when every concept in *required* is truthy in *sample_concepts*.

    A concept that is missing, or whose aggregated value is ``None`` (a tied
    vote, see ``submit_label``), counts as absent.
    """
    return all(sample_concepts.get(c, False) for c in required)


def get_next_image(
    split: str,
    concepts: list,
    filtered_indices: dict,
    selected_concepts: list,
    profile: gr.OAuthProfile,
):
    """Pick a random image from *split* that matches the concept filter.

    Args:
        split: Name of the dataset split to sample from.
        concepts: Concepts currently selected in the filter dropdown.
        filtered_indices: Per-split cache of sample indices matching
            ``selected_concepts``; mutated in place when the filter changes.
        selected_concepts: The filter that ``filtered_indices`` was computed for.
        profile: OAuth profile of the user, injected by Gradio.

    Returns:
        A 7-tuple in ``common_output`` order: image path, the concepts this
        user previously voted for, an ``"<split>:<index>"`` identifier, the
        image class, its aggregated concepts, and the (possibly recomputed)
        ``filtered_indices`` and ``selected_concepts`` states. When no image
        matches, the first five entries are ``None`` and a warning is shown.
    """
    username = profile.username
    # NOTE(review): guard in case the multiselect dropdown yields None when empty.
    if concepts is None:
        concepts = []

    # Only rebuild the index cache when the filter actually changed.
    if concepts != selected_concepts:
        for split_name, samples in global_variables.all_metadata.items():
            filtered_indices[split_name] = [
                idx
                for idx, sample in enumerate(samples)
                if _matches_filter(sample["concepts"], concepts)
            ]
        selected_concepts = concepts

    candidates = filtered_indices.get(split, [])
    if not candidates:
        gr.Warning("No image found for the selected filter.")
        return None, None, None, None, None, filtered_indices, selected_concepts

    sample_idx = random.choice(candidates)
    sample = global_variables.all_metadata[split][sample_idx]
    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"

    # Pre-tick the checkboxes with this user's earlier vote, if there is one.
    user_votes = sample.get("votes", {}).get(username, {})
    voted_concepts = [c for c in CONCEPTS if user_votes.get(c, False)]

    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        sample["concepts"],
        filtered_indices,
        selected_concepts,
    )


def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split,
    concepts,
    filtered_indices,
    selected_concepts,
    profile: gr.OAuthProfile,
):
    """Record the user's vote for the current image and load the next one.

    The user's vote marks each concept in ``CONCEPTS`` as present (checked) or
    absent. The image's aggregated concepts are then recomputed across all
    users as a majority vote: ``True`` if more users said present, ``False``
    if more said absent, ``None`` on a tie. The metadata for the split is then
    saved.

    Args:
        voted_concepts: Concepts the user checked for the current image.
        current_image: ``"<split>:<index>"`` identifier of the shown image,
            or ``None`` if no image is displayed.
        split: Split to draw the next image from.
        concepts: Concepts currently selected in the filter dropdown.
        filtered_indices: Per-split cache of sample indices matching
            ``selected_concepts``.
        selected_concepts: The filter that ``filtered_indices`` was computed for.
        profile: OAuth profile of the user, injected by Gradio.

    Returns:
        The same 7-tuple as ``get_next_image``, for the next image.
    """
    username = profile.username
    if current_image is None:
        gr.Warning("No image selected.")
        return None, None, None, None, None, filtered_indices, selected_concepts

    current_split, idx = current_image.rsplit(":", 1)
    idx = int(idx)
    global_variables.get_metadata(current_split)
    # Bound after get_metadata so we mutate whatever list it leaves in place.
    sample = global_variables.all_metadata[current_split][idx]

    votes = sample.setdefault("votes", {})
    votes[username] = {c: c in voted_concepts for c in CONCEPTS}

    # Majority vote per concept: each True vote counts +1, each False vote -1.
    # Named image_concepts so it does not shadow the ``concepts`` filter argument.
    image_concepts = {}
    for c in CONCEPTS:
        score = sum(2 * vote[c] - 1 for vote in votes.values() if c in vote)
        image_concepts[c] = score > 0 if score != 0 else None
    sample["concepts"] = image_concepts
    global_variables.save_metadata(current_split)

    # The relabelled image may no longer (or may now) match the active filter;
    # update its split's cached indices so the next draw respects the new labels.
    split_indices = filtered_indices.setdefault(current_split, [])
    matches = _matches_filter(image_concepts, selected_concepts or [])
    if matches and idx not in split_indices:
        split_indices.append(idx)
    elif not matches and idx in split_indices:
        split_indices.remove(idx)

    gr.Info("Submit success")
    return get_next_image(
        split, concepts, filtered_indices, selected_concepts, profile
    )


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "## # Image Selection",
                )
                split = gr.Radio(
                    label="Split",
                    choices=["train", "validation", "test"],
                    value="train",
                )
                concepts = gr.Dropdown(
                    label="Concepts",
                    multiselect=True,
                    choices=CONCEPTS,
                )
            with gr.Group():
                voted_concepts = gr.CheckboxGroup(
                    label="Voted Concepts",
                    choices=CONCEPTS,
                )
            with gr.Row():
                next_button = gr.Button(
                    value="Next",
                )
                gr.LoginButton()
                submit_button = gr.Button(
                    value="Submit",
                )
            with gr.Group():
                gr.Markdown(
                    "## # Image Info",
                )
                im_class = gr.Textbox(
                    label="Class",
                )
                im_concepts = gr.JSON(
                    label="Concepts",
                )
        with gr.Column():
            image = gr.Image(
                label="Image",
            )

    # Per-session state: the "<split>:<index>" id of the shown image, the
    # per-split index cache (initially every sample, i.e. the empty filter),
    # and the filter that cache corresponds to.
    current_image = gr.State(None)
    filtered_indices = gr.State({
        split_name: list(range(len(global_variables.all_metadata[split_name])))
        for split_name in global_variables.all_metadata
    })
    selected_concepts = gr.State([])

    # Both handlers return values in this order.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        filtered_indices,
        selected_concepts,
    ]
    # The gr.OAuthProfile argument is injected by Gradio, not listed in inputs.
    next_button.click(
        get_next_image,
        inputs=[split, concepts, filtered_indices, selected_concepts],
        outputs=common_output,
    )
    submit_button.click(
        submit_label,
        inputs=[
            voted_concepts,
            current_image,
            split,
            concepts,
            filtered_indices,
            selected_concepts,
        ],
        outputs=common_output,
    )