explore-label-concepts / src /label_interface.py
Xmaster6y's picture
votes not in metadata
de554eb unverified
raw
history blame
7.03 kB
"""Interface for labeling concepts in images.
"""
from typing import Optional
import random
import gradio as gr
from src import global_variables
from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME
def filter_sample(sample, concepts, username, sample_type):
    """Return whether ``sample`` passes the concept filter and the labelled filter.

    Args:
        sample: One metadata entry; must hold a ``"concepts"`` mapping and,
            when it has no ``"votes"`` field, an ``"id"`` used as the key into
            ``global_variables.all_votes``.
        concepts: Concepts that must all be truthy in ``sample["concepts"]``;
            an empty list matches every sample.
        username: User whose votes decide whether the sample is labelled.
        sample_type: ``"labelled"`` or ``"unlabelled"``.

    Returns:
        True when every requested concept is set and the sample's labelled
        state matches ``sample_type``. A sample counts as labelled when the
        user's vote has an entry for every concept in ``CONCEPTS``.

    Raises:
        ValueError: If ``sample_type`` is neither ``"labelled"`` nor
            ``"unlabelled"`` (only checked once the concept filter passes).
    """
    if not all(sample["concepts"].get(c, False) for c in concepts):
        return False
    if "votes" in sample:
        votes = sample["votes"]
    else:
        # submit_label records votes in global_variables.all_votes and never
        # writes a "votes" field into the metadata entry, so look them up there.
        # A sample with no entry in all_votes counts as unlabelled.
        try:
            votes = global_variables.all_votes[sample["id"]]
        except KeyError:
            votes = {}
    if username in votes:
        is_labelled = all(c in votes[username] for c in CONCEPTS)
    else:
        is_labelled = False
    if sample_type == "labelled":
        return is_labelled
    if sample_type == "unlabelled":
        return not is_labelled
    raise ValueError(f"Invalid sample type: {sample_type}")
def get_next_image(
    split: str,
    concepts: list,
    sample_type: str,
    filtered_indices: dict,
    selected_concepts: list,
    selected_sample_type: str,
    profile: gr.OAuthProfile
):
    """Pick a random image from ``split`` that matches the current filter.

    ``filtered_indices`` caches, per split, the indices of samples that pass
    ``filter_sample``. It is rebuilt for every split only when ``concepts`` or
    ``sample_type`` differ from the settings it was last built for
    (``selected_concepts`` / ``selected_sample_type``).

    Returns:
        A 10-tuple ``(image_path, voted_concepts, "<split>:<index>", class,
        concepts, unseen_concepts, tie_concepts, filtered_indices,
        selected_concepts, selected_sample_type)``. When no sample matches,
        a warning is shown and the first seven entries are None.
    """
    username = profile.username
    if concepts != selected_concepts or sample_type != selected_sample_type:
        for key, samples in global_variables.all_metadata.items():
            filtered_indices[key] = [
                i for i, s in enumerate(samples)
                if filter_sample(s, concepts, username, sample_type)
            ]
        selected_concepts = concepts
        selected_sample_type = sample_type

    # Check for an empty candidate list explicitly rather than wrapping the
    # whole body in `except IndexError`, which would also hide unrelated
    # IndexErrors. A split missing from the cache is treated as "no matches".
    candidates = filtered_indices.get(split, [])
    if not candidates:
        gr.Warning("No image found for the selected filter.")
        return None, None, None, None, None, None, None, filtered_indices, selected_concepts, selected_sample_type

    sample_idx = random.choice(candidates)
    sample = global_variables.all_metadata[split][sample_idx]
    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"
    try:
        username_votes = global_variables.all_votes[sample["id"]][username]
    except KeyError:
        # No stored vote from this user for this sample.
        voted_concepts = []
        unseen_concepts = []
    else:
        voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)]
        # Concepts the user's stored vote has no entry for.
        unseen_concepts = [c for c in CONCEPTS if c not in username_votes]
    # Concepts whose aggregated label is None (submit_label stores None on a tie).
    tie_concepts = [c for c, value in sample["concepts"].items() if value is None]
    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        sample["concepts"],
        unseen_concepts,
        tie_concepts,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
    )
def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split,
    concepts,
    sample_type,
    filtered_indices,
    selected_concepts,
    selected_sample_type,
    profile: gr.OAuthProfile
):
    """Store the user's vote on the displayed image, then load the next image.

    The vote is written to ``global_variables.all_votes``. The image's
    ``"concepts"`` metadata is then recomputed from all stored votes, and both
    the metadata and the votes are saved.

    Args:
        voted_concepts: Concepts the user ticked; every other concept in
            ``CONCEPTS`` is recorded as a negative (False) vote.
        current_image: ``"<split>:<index>"`` identifier of the displayed image,
            or None when no image is shown.
        split, concepts, sample_type: Current filter controls, forwarded to
            ``get_next_image``.
        filtered_indices: Cached per-split candidate indices.
        selected_concepts, selected_sample_type: Filter settings the cache
            was built for.
        profile: Logged-in user's profile; ``profile.username`` keys the vote.

    Returns:
        The same 10-tuple as ``get_next_image``.
    """
    username = profile.username
    if current_image is None:
        gr.Warning("No image selected.")
        return None, None, None, None, None, None, None, filtered_indices, selected_concepts, selected_sample_type
    current_split, idx = current_image.split(":")
    idx = int(idx)
    global_variables.get_metadata(current_split)
    s_id = global_variables.all_metadata[current_split][idx]["id"]
    global_variables.get_votes(s_id)
    if s_id not in global_variables.all_votes:
        global_variables.all_votes[s_id] = {}
    global_variables.all_votes[s_id][username] = {c: c in voted_concepts for c in CONCEPTS}

    # Majority vote per concept: each True vote counts +1, each False vote -1.
    # A net score of 0 is stored as None (a tie).
    vote_sum = {c: 0 for c in CONCEPTS}
    new_concepts = {}
    for c in CONCEPTS:
        for vote in global_variables.all_votes[s_id].values():
            if c not in vote:
                continue
            vote_sum[c] += 2 * vote[c] - 1
        new_concepts[c] = vote_sum[c] > 0 if vote_sum[c] != 0 else None
    sample = global_variables.all_metadata[current_split][idx]
    sample["concepts"] = new_concepts
    global_variables.save_metadata(current_split)
    global_variables.save_votes(s_id)

    # filtered_indices was computed before this vote and is only rebuilt when
    # the filter settings change. Drop this image if it no longer matches the
    # active filter; otherwise get_next_image could serve it again.
    if selected_sample_type is not None and current_split in filtered_indices:
        if not filter_sample(sample, selected_concepts or [], username, selected_sample_type):
            filtered_indices[current_split] = [
                i for i in filtered_indices[current_split] if i != idx
            ]

    gr.Info("Submit success")
    return get_next_image(
        split,
        concepts,
        sample_type,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
        profile
    )
# Gradio UI for labelling concepts. The left column holds the filter controls,
# the vote checkboxes, the Next/Login/Submit buttons and the image info; the
# right column shows the image. Creation order and `with` nesting define the
# layout.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            # Filter controls: get_next_image draws only from samples of
            # `split` that pass filter_sample for `concepts` and `sample_type`.
            with gr.Group():
                gr.Markdown(
                    "## # Image Selection",
                )
                with gr.Row():
                    split = gr.Radio(
                        label="Split",
                        choices=["train", "validation", "test"],
                        value="train",
                    )
                    sample_type = gr.Radio(
                        label="Sample Type",
                        choices=["labelled", "unlabelled"],
                        value="unlabelled",
                    )
                concepts = gr.Dropdown(
                    label="Concepts",
                    multiselect=True,
                    choices=CONCEPTS,
                )
            # Per-image concept views filled by get_next_image. Of these three,
            # only `voted_concepts` is passed back to submit_label.
            with gr.Group():
                voted_concepts = gr.CheckboxGroup(
                    label="Voted Concepts",
                    choices=CONCEPTS,
                )
                unseen_concepts = gr.CheckboxGroup(
                    label="Previously Unseen Concepts",
                    choices=CONCEPTS,
                )
                tie_concepts = gr.CheckboxGroup(
                    label="Tie Concepts",
                    choices=CONCEPTS,
                )
            with gr.Row():
                next_button = gr.Button(
                    value="Next",
                )
                gr.LoginButton()
                submit_button = gr.Button(
                    value="Submit",
                )
            # Details of the displayed sample (its class and concepts metadata).
            with gr.Group():
                gr.Markdown(
                    "## # Image Info",
                )
                im_class = gr.Textbox(
                    label="Class",
                )
                im_concepts = gr.JSON(
                    label="Concepts",
                )
        with gr.Column():
            image = gr.Image(
                label="Image",
            )
    # "<split>:<index>" identifier of the displayed image (None until the
    # first image is loaded).
    current_image = gr.State(None)
    # Cached candidate indices per split; initially every index of every split.
    filtered_indices = gr.State({
        split: list(range(len(global_variables.all_metadata[split])))
        for split in global_variables.all_metadata
    })
    # Filter settings that filtered_indices was last built for. The initial
    # None differs from any Radio value, so the first Next rebuilds the cache.
    selected_concepts = gr.State([])
    selected_sample_type = gr.State(None)
    # Output order must match the 10-tuple returned by get_next_image and
    # submit_label.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        unseen_concepts,
        tie_concepts,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
    ]
    # The handlers' `profile: gr.OAuthProfile` argument is not listed in
    # `inputs`; presumably Gradio injects it from the type hint — confirm.
    next_button.click(
        get_next_image,
        inputs=[split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type],
        outputs=common_output
    )
    submit_button.click(
        submit_label,
        inputs=[voted_concepts, current_image, split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type],
        outputs=common_output
    )