# Page-capture residue (not part of the module): "Spaces: Runtime error / Runtime error"
"""Interface for labeling concepts in images."""
from typing import Optional | |
import random | |
import gradio as gr | |
from src import global_variables | |
from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME | |
def filter_sample(sample, concepts, username, sample_type):
    """Decide whether a metadata sample matches the active image filter.

    Args:
        sample: Metadata record. Must hold a ``"concepts"`` mapping and may
            hold a ``"votes"`` mapping keyed by username.
        concepts: Concept names that must all be truthy in
            ``sample["concepts"]``. ``None`` or an empty selection imposes no
            concept constraint.
        username: User whose labelling status is checked.
        sample_type: Either ``"labelled"`` or ``"unlabelled"``.

    Returns:
        ``False`` when any required concept is missing or not truthy.
        Otherwise ``True`` exactly when the user's labelling status matches
        ``sample_type``. A sample counts as labelled by the user only when
        the user's vote record contains every entry of ``CONCEPTS``.

    Raises:
        ValueError: If the concept constraint passes and ``sample_type`` is
            neither ``"labelled"`` nor ``"unlabelled"``.
    """
    sample_concepts = sample["concepts"]
    # ``or ()`` lets a None selection behave like an empty one; the generator
    # lets all() stop at the first concept that fails.
    if not all(sample_concepts.get(c, False) for c in concepts or ()):
        return False

    votes = sample.get("votes", {})
    is_labelled = username in votes and all(
        c in votes[username] for c in CONCEPTS
    )

    if sample_type == "labelled":
        return is_labelled
    if sample_type == "unlabelled":
        return not is_labelled
    raise ValueError(f"Invalid sample type: {sample_type}")
def get_next_image(
    split: str,
    concepts: list,
    sample_type: str,
    filtered_indices: dict,
    selected_concepts: list,
    selected_sample_type: str,
    profile: gr.OAuthProfile
):
    """Pick a random sample matching the current filter and describe it.

    The per-split index lists in ``filtered_indices`` are recomputed (in
    place) only when ``concepts`` or ``sample_type`` differ from the values
    they were last computed for.

    Args:
        split: Key of the split to draw from.
        concepts: Concepts every candidate sample must exhibit.
        sample_type: ``"labelled"`` or ``"unlabelled"``.
        filtered_indices: Mapping from split key to the list of sample
            indices that pass the filter.
        selected_concepts: Concepts ``filtered_indices`` was last built for.
        selected_sample_type: Sample type ``filtered_indices`` was last
            built for.
        profile: Profile of the requesting user; only ``username`` is read.
            It is not listed in the event ``inputs`` in this file, so it is
            presumably injected by Gradio from the annotation.

    Returns:
        A 10-tuple: image path, the user's positively voted concepts, the
        ``"<split>:<index>"`` identifier, the sample class, the sample's
        concept mapping, concepts absent from the user's vote record,
        concepts whose value is ``None`` (ties), and the three filter-state
        values. When no sample is available the first seven entries are
        ``None``.
    """
    username = profile.username

    if concepts != selected_concepts or sample_type != selected_sample_type:
        for key, samples in global_variables.all_metadata.items():
            filtered_indices[key] = [
                i
                for i, candidate in enumerate(samples)
                if filter_sample(candidate, concepts, username, sample_type)
            ]
        selected_concepts = concepts
        selected_sample_type = sample_type

    try:
        # random.choice raises IndexError on an empty candidate list.
        sample_idx = random.choice(filtered_indices[split])
        sample = global_variables.all_metadata[split][sample_idx]
    except IndexError:
        gr.Warning("No image found for the selected filter.")
        # Exactly seven placeholders plus three state values, so the arity
        # matches the success path below.
        return (
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            filtered_indices,
            selected_concepts,
            selected_sample_type,
        )

    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"

    try:
        username_votes = sample["votes"][username]
    except KeyError:
        # No vote record for this user on this sample.
        voted_concepts = []
        unseen_concepts = []
    else:
        voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)]
        unseen_concepts = [c for c in CONCEPTS if c not in username_votes]

    tie_concepts = [c for c, verdict in sample["concepts"].items() if verdict is None]

    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        sample["concepts"],
        unseen_concepts,
        tie_concepts,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
    )
def _aggregate_votes(votes):
    """Return the majority verdict per concept over every user's vote record."""
    verdicts = {}
    for c in CONCEPTS:
        # +1 per True vote, -1 per False vote; records lacking the concept
        # are skipped.
        balance = sum(2 * vote[c] - 1 for vote in votes.values() if c in vote)
        # A zero balance (tie, or nobody voted) is stored as None.
        verdicts[c] = balance > 0 if balance != 0 else None
    return verdicts


def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split,
    concepts,
    sample_type,
    filtered_indices,
    selected_concepts,
    selected_sample_type,
    profile: gr.OAuthProfile
):
    """Record the user's votes for the displayed image, then load the next one.

    Args:
        voted_concepts: Concepts the user ticked; every other entry of
            ``CONCEPTS`` is stored as a ``False`` vote.
        current_image: ``"<split>:<index>"`` identifier of the displayed
            sample, or ``None`` when nothing is displayed.
        split, concepts, sample_type, filtered_indices, selected_concepts,
            selected_sample_type: Forwarded unchanged to ``get_next_image``.
        profile: Profile of the requesting user; only ``username`` is read.

    Returns:
        The 10-tuple produced by ``get_next_image``. When ``current_image``
        is ``None``: seven ``None`` placeholders followed by the three
        unchanged filter-state values.
    """
    username = profile.username

    if current_image is None:
        gr.Warning("No image selected.")
        # Exactly seven placeholders plus three state values, so the arity
        # matches what get_next_image returns.
        return (
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            filtered_indices,
            selected_concepts,
            selected_sample_type,
        )

    current_split, idx = current_image.split(":")
    idx = int(idx)

    # Presumably refreshes all_metadata[current_split] before it is edited;
    # verify against src.global_variables.
    global_variables.get_metadata(current_split)
    sample = global_variables.all_metadata[current_split][idx]

    ticked = set(voted_concepts or ())
    votes = sample.setdefault("votes", {})
    votes[username] = {c: c in ticked for c in CONCEPTS}
    sample["concepts"] = _aggregate_votes(votes)

    global_variables.save_metadata(current_split)
    gr.Info("Submit success")

    # NOTE(review): filtered_indices is not refreshed here, so a sample that
    # was just labelled can still be drawn under the "unlabelled" filter
    # until the filter selection changes. Confirm whether that is intended.
    return get_next_image(
        split,
        concepts,
        sample_type,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
        profile,
    )
# NOTE(review): the captured source lost its indentation, so the nesting of
# the layout containers below is reconstructed. Confirm against the
# deployed layout.
with gr.Blocks() as interface:
    with gr.Row():
        # Left column: filter controls, voting widgets, actions, image info.
        with gr.Column():
            with gr.Group():
                gr.Markdown("## # Image Selection")
                with gr.Row():
                    split = gr.Radio(
                        choices=["train", "validation", "test"],
                        value="train",
                        label="Split",
                    )
                    sample_type = gr.Radio(
                        choices=["labelled", "unlabelled"],
                        value="unlabelled",
                        label="Sample Type",
                    )
                concepts = gr.Dropdown(choices=CONCEPTS, multiselect=True, label="Concepts")

            with gr.Group():
                voted_concepts = gr.CheckboxGroup(choices=CONCEPTS, label="Voted Concepts")
                unseen_concepts = gr.CheckboxGroup(choices=CONCEPTS, label="Previously Unseen Concepts")
                tie_concepts = gr.CheckboxGroup(choices=CONCEPTS, label="Tie Concepts")

            with gr.Row():
                next_button = gr.Button(value="Next")
                gr.LoginButton()
                submit_button = gr.Button(value="Submit")

            with gr.Group():
                gr.Markdown("## # Image Info")
                im_class = gr.Textbox(label="Class")
                im_concepts = gr.JSON(label="Concepts")

        # Right column: the image being labelled.
        with gr.Column():
            image = gr.Image(label="Image")

    # Session state: identifier of the displayed sample and the filter cache.
    current_image = gr.State(None)
    # Initial cache: every index of every split.
    filtered_indices = gr.State({
        split_name: list(range(len(samples)))
        for split_name, samples in global_variables.all_metadata.items()
    })
    selected_concepts = gr.State([])
    selected_sample_type = gr.State(None)

    # Both handlers return their values in this order.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        unseen_concepts,
        tie_concepts,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
    ]

    next_button.click(
        get_next_image,
        inputs=[
            split,
            concepts,
            sample_type,
            filtered_indices,
            selected_concepts,
            selected_sample_type,
        ],
        outputs=common_output,
    )
    submit_button.click(
        submit_label,
        inputs=[
            voted_concepts,
            current_image,
            split,
            concepts,
            sample_type,
            filtered_indices,
            selected_concepts,
            selected_sample_type,
        ],
        outputs=common_output,
    )