# explore-label-concepts / src/label_interface.py
# Author: Xmaster6y
# Last commit: inverted button order
# Commit hash: 7896579
"""Interface for labeling concepts in images.
"""
from typing import Optional
import random
import gradio as gr
from src import global_variables
from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME
def filter_sample(sample, concepts, username, sample_type):
    """Decide whether ``sample`` passes the concept and labelling filters.

    Args:
        sample: Metadata record, indexed by concept name and optionally
            holding a ``"votes"`` mapping of username -> per-user votes.
        concepts: Concepts whose values must all be truthy on ``sample``.
            A tie value (``None``) counts as not having the concept.
        username: User whose votes decide whether the sample is labelled.
        sample_type: ``"labelled"`` keeps samples where the user has a vote
            entry for every concept in ``CONCEPTS``; ``"unlabelled"`` keeps
            the others.

    Returns:
        True if the sample should be kept, False otherwise.

    Raises:
        ValueError: If ``sample_type`` is not recognized. This is only
            reached once the concept filter has passed.
    """
    # Generator (not a list) so evaluation stops at the first falsy concept
    # instead of indexing every concept unconditionally.
    if not all(sample[c] for c in concepts):
        return False

    votes = sample["votes"] if "votes" in sample else None
    # Tolerate a missing/None votes mapping or a None per-user entry:
    # both mean "no votes recorded" rather than a crash.
    user_votes = votes.get(username) if votes else None
    is_labelled = user_votes is not None and all(c in user_votes for c in CONCEPTS)

    if sample_type == "labelled":
        return is_labelled
    if sample_type == "unlabelled":
        return not is_labelled
    raise ValueError(f"Invalid sample type: {sample_type}")
def get_next_image(
    split: str,
    concepts: list,
    sample_type: str,
    filtered_indices: dict,
    selected_concepts: list,
    selected_sample_type: str,
    profile: gr.OAuthProfile
):
    """Pick a random sample that matches the current filters and build the UI outputs.

    The per-split cache in ``filtered_indices`` is rebuilt (for every split in
    ``global_variables.all_metadata``) only when the requested concepts or
    sample type differ from the ones it was last built for.

    Args:
        split: Key of ``global_variables.all_metadata`` to draw the sample from.
        concepts: Concepts every candidate must have (see ``filter_sample``).
            ``None`` is treated as an empty selection.
        sample_type: ``"labelled"`` or ``"unlabelled"``.
        filtered_indices: Per-split cache of candidate indices. It is mutated
            in place when rebuilt and returned so Gradio stores it back.
        selected_concepts: Concepts the cache was last built for.
        selected_sample_type: Sample type the cache was last built for.
        profile: Profile of the logged-in user, injected by Gradio.

    Returns:
        A 10-tuple, in this order:

        - image path
        - the user's voted concepts
        - the ``"split:idx"`` sample id
        - the class
        - a concept -> value dict
        - unseen concepts
        - tie concepts
        - ``filtered_indices``
        - ``selected_concepts``
        - ``selected_sample_type``

        When no sample matches, the first seven entries are ``None`` and a
        warning toast is shown.
    """
    username = profile.username
    # An empty multiselect Dropdown may arrive as None; normalize it so that
    # filter_sample can iterate it.
    concepts = concepts or []

    if concepts != selected_concepts or sample_type != selected_sample_type:
        for key, samples in global_variables.all_metadata.items():
            filtered_indices[key] = [
                i
                for i, candidate in enumerate(samples)
                if filter_sample(candidate, concepts, username, sample_type)
            ]
        selected_concepts = concepts
        selected_sample_type = sample_type
    # NOTE(review): the cache is not rebuilt after a submit, so a sample the
    # user just labelled can be drawn again in "unlabelled" mode — confirm
    # whether update_votes is expected to refresh this.

    try:
        # random.choice raises IndexError when no candidate matches.
        sample_idx = random.choice(filtered_indices[split])
    except IndexError:
        gr.Warning("No image found for the selected filter.")
        return None, None, None, None, None, None, None, filtered_indices, selected_concepts, selected_sample_type

    sample = global_variables.all_metadata[split][sample_idx]
    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"

    try:
        username_votes = global_variables.all_votes[sample["id"]][username]
    except KeyError:
        # No votes recorded by this user for this sample.
        # NOTE(review): unseen_concepts is empty here rather than all
        # CONCEPTS — confirm that this is intended.
        voted_concepts = []
        unseen_concepts = []
    else:
        voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)]
        unseen_concepts = [c for c in CONCEPTS if c not in username_votes]

    # Concepts whose aggregated value is None are shown as ties.
    tie_concepts = [c for c in CONCEPTS if sample[c] is None]
    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        {c: sample[c] for c in CONCEPTS},
        unseen_concepts,
        tie_concepts,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
    )
def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split,
    concepts,
    sample_type,
    filtered_indices,
    selected_concepts,
    selected_sample_type,
    profile: gr.OAuthProfile
):
    """Store the user's votes for the displayed sample, then load the next one.

    ``current_image`` is the ``"split:idx"`` id produced by ``get_next_image``.
    When it is ``None``, a warning toast is shown and the seven display
    outputs are cleared, while the three state values are passed through
    unchanged. Otherwise the votes go to ``global_variables.update_votes`` and
    the result of ``get_next_image`` is returned.
    """
    user = profile.username

    if current_image is not None:
        global_variables.update_votes(user, current_image, voted_concepts)
        gr.Info("Submit success")
        return get_next_image(
            split=split,
            concepts=concepts,
            sample_type=sample_type,
            filtered_indices=filtered_indices,
            selected_concepts=selected_concepts,
            selected_sample_type=selected_sample_type,
            profile=profile,
        )

    gr.Warning("No image selected.")
    cleared_outputs = (None,) * 7
    return cleared_outputs + (filtered_indices, selected_concepts, selected_sample_type)
def save_current_work(
    profile: gr.OAuthProfile,
):
    """Ask ``global_variables`` to save the logged-in user's work, then show a toast."""
    global_variables.save_current_work(profile.username)
    gr.Info("Save success")
# Gradio layout and event wiring for the concept-labelling UI.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            # Filters that decide which samples get_next_image draws from.
            with gr.Group():
                gr.Markdown(
                    "## # Image Selection",
                )
                with gr.Row():
                    split = gr.Radio(
                        label="Split",
                        choices=["train", "test"],
                        value="train",
                    )
                    sample_type = gr.Radio(
                        label="Sample Type",
                        choices=["labelled", "unlabelled"],
                        value="unlabelled",
                    )
                concepts = gr.Dropdown(
                    label="Concepts",
                    multiselect=True,
                    choices=CONCEPTS,
                )
            with gr.Row():
                next_button = gr.Button(
                    value="Next",
                )
                gr.LoginButton()
                submit_button = gr.Button(
                    value="Local Submit",
                )
            with gr.Row():
                save_button = gr.Button(
                    value="Save",
                )
            with gr.Group():
                voted_concepts = gr.CheckboxGroup(
                    label="Voted Concepts",
                    choices=CONCEPTS,
                )
                unseen_concepts = gr.CheckboxGroup(
                    label="Previously Unseen Concepts",
                    choices=CONCEPTS,
                )
                tie_concepts = gr.CheckboxGroup(
                    label="Tie Concepts",
                    choices=CONCEPTS,
                )
            with gr.Group():
                gr.Markdown(
                    "## # Image Info",
                )
                im_class = gr.Textbox(
                    label="Class",
                )
                im_concepts = gr.JSON(
                    label="Concepts",
                )
        with gr.Column():
            image = gr.Image(
                label="Image",
            )

    # Per-session state. current_image holds the "split:idx" id of the
    # displayed sample. filtered_indices starts with every index of every
    # split. The selected_* values record which filters the cache was built
    # for; the None sample type forces a rebuild on the first "Next".
    current_image = gr.State(None)
    filtered_indices = gr.State({
        split_name: list(range(len(samples)))
        for split_name, samples in global_variables.all_metadata.items()
    })
    selected_concepts = gr.State([])
    selected_sample_type = gr.State(None)

    # Must match, in order, the tuples returned by get_next_image/submit_label.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        unseen_concepts,
        tie_concepts,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
    ]
    next_button.click(
        get_next_image,
        inputs=[split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type],
        outputs=common_output
    )
    submit_button.click(
        submit_label,
        inputs=[voted_concepts, current_image, split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type],
        outputs=common_output
    )
    # save_current_work returns nothing. Wiring it to outputs=[image] made
    # every Save clear the displayed image while current_image still
    # referenced it.
    save_button.click(
        save_current_work,
    )