"""Gradio app for crowd-sourced review of flagged `ImageNet-Hard` samples.

A reviewer enters a username, loads the flagged images they have not judged
yet, and marks each one `Accept`, `Not Sure!` or `Reject`. Every decision is
uploaded as a small JSON file to a Hugging Face dataset repository, which is
also the source of truth for what a given user has already answered.
"""

import csv
import json
import os
import pickle
import random
import re
import string
import sys
import time
from glob import glob

import datasets
import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from huggingface_hub import HfApi, login, snapshot_download
from PIL import Image

# Dataset repository that stores one JSON file per review decision. Reading
# (update_snapshot) and writing (update_app) must target the same repository,
# otherwise uploaded decisions are never seen again when samples are reloaded.
REVIEW_REPO_ID = "taesiri/imagenet_hard_review_data_r2"

session_token = os.environ.get("SessionToken")
login(token=session_token)

csv.field_size_limit(sys.maxsize)

np.random.seed(int(time.time()))

# NOTE(review): pickle.load executes arbitrary code from the file; this is
# only acceptable because the file ships with the app. Never point this at
# untrusted data.
with open("./imagenet_hard_nearest_indices.pkl", "rb") as f:
    knn_results = pickle.load(f)

with open("imagenet-labels.json") as f:
    wnid_to_label = json.load(f)

with open("id_to_label.json", "r") as f:
    id_to_labels = json.load(f)

imagenet_training_samples_path = "imagenet_samples"

# ex2.txt lists one flagged item per line as "<id>.<ext>"; keep the integer id.
with open("./ex2.txt", "r") as f:
    bad_items = f.read().split("\n")
bad_items = [x.split(".")[0] for x in bad_items]
bad_items = [int(x) for x in bad_items if x != ""]

NUMBER_OF_IMAGES = len(bad_items)

gdown.cached_download(
    url="https://huggingface.co/datasets/taesiri/imagenet_hard_review_samples/resolve/main/data.zip",
    path="./data.zip",
    quiet=False,
    md5="ece2720fed664e71799f316a881d4324",
)

# EXTRACT if needed
if not os.path.exists("./imagenet_samples") or not os.path.exists(
    "./knn_cache_for_imagenet_hard"
):
    torchvision.datasets.utils.extract_archive(
        from_path="data.zip",
        to_path="./",
        remove_finished=False,
    )

imagenet_hard = datasets.load_dataset("taesiri/imagenet-hard", split="validation")


def update_snapshot(username):
    """Download all stored decisions and return the ones made by `username`.

    Args:
        username: Value compared against the "user_id" field of each record.

    Returns:
        A DataFrame with columns "id", "user_id", "time" and "decision",
        restricted to rows whose "user_id" equals `username`.

    Raises:
        KeyError: If a downloaded JSON file lacks one of the expected keys.
    """
    output_dir = snapshot_download(
        repo_id=REVIEW_REPO_ID,
        allow_patterns="*.json",
        repo_type="dataset",
    )
    columns = ["id", "user_id", "time", "decision"]
    rows = []
    for file in glob(f"{output_dir}/*.json"):
        with open(file) as f:
            record = json.load(f)
        rows.append([record[column] for column in columns])

    df = pd.DataFrame(rows, columns=columns)
    return df[df["user_id"] == username]


def generate_dataset(username):
    """Build the shuffled list of flagged images `username` has not reviewed.

    Also refreshes the module-level NUMBER_OF_IMAGES with the number of
    flagged items.

    Args:
        username: Reviewer whose previous decisions are excluded.

    Returns:
        A list of ``{"id": image_id}`` dicts in random order; empty when
        there are no flagged items or all of them have been answered.
    """
    global NUMBER_OF_IMAGES

    df = update_snapshot(username)

    all_images = set(bad_items)
    answered = set(df.id)
    remaining = list(all_images - answered)
    random.shuffle(remaining)

    NUMBER_OF_IMAGES = len(bad_items)

    print(f"NUMBER_OF_IMAGES: {NUMBER_OF_IMAGES}")
    print(f"Remaining: {len(remaining)}")

    if NUMBER_OF_IMAGES == 0:
        return []

    return [{"id": image_id} for image_id in remaining]


def string_to_image(text):
    """Render `text` as a matplotlib figure on a blank white background.

    Underscores become spaces, the text is lower-cased and every ", "
    separator becomes a line break, so each label is drawn on its own line.

    Args:
        text: Comma-separated label string or a status message.

    Returns:
        The matplotlib figure containing the rendered text.
    """
    text = text.replace("_", " ").lower().replace(", ", "\n")

    # Blank white canvas used purely as a background for the text.
    img = np.ones((220, 75, 3))

    fig, ax = plt.subplots(figsize=(6, 2.25))
    ax.imshow(img, extent=[0, 1, 0, 1])
    ax.text(0.5, 0.75, text, fontsize=18, ha="center", va="center")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    # pyplot keeps every figure alive until it is explicitly closed; in a
    # long-running server that leaks one figure per click. Closing only
    # unregisters the figure from pyplot - the object itself can still be
    # rendered by the caller.
    plt.close(fig)

    return fig


all_samples = glob("./imagenet_samples/*.JPEG")
# Sample files are named "<label id>_<...>.JPEG"; index them by the leading
# integer.
qid_to_sample = {
    int(os.path.basename(x).split(".")[0].split("_")[0]): x for x in all_samples
}


def get_training_samples(qid):
    """Return the sample image paths for the labels of dataset item `qid`.

    Args:
        qid: Index into `imagenet_hard`.

    Returns:
        One path from `qid_to_sample` per entry in the item's "label" field.

    Raises:
        KeyError: If a label id has no file in `qid_to_sample`.
    """
    labels_id = imagenet_hard[int(qid)]["label"]
    return [qid_to_sample[x] for x in labels_id]


def load_sample(data, current_index):
    """Fetch the query image and English labels for `data[current_index]`.

    Args:
        data: List of ``{"id": image_id}`` dicts from `generate_dataset`.
        current_index: Position in `data` of the item to load.

    Returns:
        A ``(image, labels)`` tuple taken from the "image" and
        "english_label" fields of the matching `imagenet_hard` row.
    """
    image_id = data[current_index]["id"]
    sample = imagenet_hard[int(image_id)]
    return sample["image"], sample["english_label"]


def _load_training_images(image_id):
    """Open the training-sample images for `image_id` as RGB PIL images."""
    return [Image.open(x).convert("RGB") for x in get_training_samples(image_id)]


def _placeholder_outputs(message, data, current_index, history, labeled_images):
    """Build the 7-tuple of UI outputs for a state with no image to show."""
    fake_plot = string_to_image(message)
    empty_image = Image.new("RGB", (224, 224))
    return (
        empty_image,
        fake_plot,
        current_index,
        history,
        data,
        None,
        labeled_images,
    )


def _sample_outputs(data, current_index, history, labeled_images):
    """Build the 7-tuple of UI outputs that displays `data[current_index]`."""
    qimage, labels = load_sample(data, current_index)
    image_id = data[current_index]["id"]
    training_samples_image = _load_training_images(image_id)

    # labels is a list of labels, convert it to a single string
    label_plot = string_to_image(", ".join(labels))

    return (
        qimage,
        label_plot,
        current_index,
        history,
        data,
        training_samples_image,
        labeled_images,
    )


def _upload_decision(decision, image_id, username):
    """Upload one review decision as a JSON file to REVIEW_REPO_ID.

    The record keeps the username exactly as entered. Only the file name is
    sanitised, because the username is free text from the UI and must not be
    able to inject path separators into the repository path.
    """
    time_stamp = int(time.time())
    decision_record = {
        "id": int(image_id),
        "user_id": username,
        "time": time_stamp,
        "decision": decision,
    }
    safe_username = re.sub(r"[^A-Za-z0-9._-]", "_", str(username))
    filename = f"results_{safe_username}_{time_stamp}.json"

    # Upload straight from memory: no temporary file to leak if the request
    # fails part-way.
    api = HfApi()
    api.upload_file(
        path_or_fileobj=json.dumps(decision_record).encode("utf-8"),
        path_in_repo=filename,
        repo_id=REVIEW_REPO_ID,
        repo_type="dataset",
    )


def preprocessing(data, current_index, history, username):
    """Load the review queue for `username` and show its first image.

    Args:
        data: Previous queue state; ignored and replaced by a fresh queue.
        current_index: Previous index; returned unchanged when nothing is
            left to review, otherwise reset to 0.
        history: Passed through untouched.
        username: Reviewer whose unanswered items are loaded.

    Returns:
        A 7-tuple ``(query image, label figure, current_index, history,
        data, training sample images or None, labeled image count)``.
    """
    data = generate_dataset(username)

    remaining_images = len(data)
    labeled_images = len(bad_items) - remaining_images

    if remaining_images == 0:
        return _placeholder_outputs(
            "No more images to review", data, current_index, history, labeled_images
        )

    return _sample_outputs(data, 0, history, labeled_images)


def update_app(decision, data, current_index, history, username):
    """Record `decision` for the current image and advance to the next one.

    Args:
        decision: Label of the button that was clicked.
        data: Review queue produced by `preprocessing`.
        current_index: Position in `data` of the image being judged; -1
            means no queue has been loaded yet.
        history: Passed through untouched.
        username: Reviewer name stored with the decision.

    Returns:
        A 7-tuple ``(query image, label figure, current_index, history,
        data, training sample images or None, labeled image count)``. Once
        the last queued image has been judged, the returned index is
        ``len(data)`` so that further clicks do not upload duplicates.
    """
    if current_index < 0:
        return _placeholder_outputs(
            "Please Enter your username and load samples",
            data,
            current_index,
            history,
            0,
        )

    # The queue only holds images this user has not answered yet, so its own
    # length - not the global count of flagged items - marks the end.
    remaining_images = len(data)
    if current_index >= remaining_images:
        return _placeholder_outputs(
            "Thank you for your time!", data, current_index, history, len(bad_items)
        )

    _upload_decision(decision, data[current_index]["id"], username)

    next_index = current_index + 1
    labeled_images = (len(bad_items) - remaining_images) + next_index

    if next_index >= remaining_images:
        return _placeholder_outputs(
            "Thank you for your time!", data, next_index, history, labeled_images
        )

    # Load the next image
    return _sample_outputs(data, next_index, history, labeled_images)


newcss = """
#query_image{
}

#nn_gallery {
  height: auto !important;
}

#sample_gallery {
  height: auto !important;
}

/* Set display to flex for the parent element */
.svelte-parentrowclass {
  display: flex;
}

/* Set the flex-grow property for the children elements */
.svelte-parentrowclass > #query_image {
  min-width: min(400px, 40%);
  flex : 1;
  flex-grow: 0; !important;
  border-style: solid;
  height: auto !important;
}

.svelte-parentrowclass > .svelte-rightcolumn {
  flex: 2;
  flex-grow: 0; !important;
  min-width: min(600px, 60%);
}
"""

with gr.Blocks(css=newcss, theme=gr.themes.Soft()) as demo:
    # Per-session state: the review queue, the position in it (-1 until a
    # queue has been loaded) and a pass-through history object.
    data_gr = gr.State({})
    current_index = gr.State(-1)
    history = gr.State({})

    gr.Markdown("# Help Us to Clean `ImageNet-Hard`!")

    gr.Markdown("## Instructions")
    gr.Markdown(
        "Please enter your username and press `Load Samples`. The loading process might take up to a minute. Once the loading is done, you can start reviewing the samples."
    )
    gr.Markdown(
        """For each image, please select one of the following options: `Accept`, `Not Sure!`, `Reject`.
- If you think any of the labels are correct, please select `Accept`.
- If you think none of the labels matching the image, please select `Reject`.
- If you are not sure about the label, please select `Not Sure!`.

You can refer to `Training samples` if you are not sure about the target label.
"""
    )

    random_str = "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(5)
    )

    with gr.Column():
        with gr.Row():
            username = gr.Textbox(label="Username", value=f"user-{random_str}")
            labeled_images = gr.Textbox(label="Labeled Images", value="0")
            total_images = gr.Textbox(label="Total Images", value=len(bad_items))
            prepare_btn = gr.Button(value="Load Samples")

    with gr.Column():
        with gr.Row():
            accept_btn = gr.Button(value="Accept")
            myabe_btn = gr.Button(value="Not Sure!")
            reject_btn = gr.Button(value="Reject")
        with gr.Row(elem_id="parent_row", elem_classes="svelte-parentrowclass"):
            query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
            with gr.Column(
                elem_id="samples_col",
                elem_classes="svelte-rightcolumn",
            ):
                label_plot = gr.Plot(
                    label="Is this a correct label for this image?", type="fig"
                )
                training_samples = gr.Gallery(
                    type="pil", label="Training samples", elem_id="sample_gallery"
                )

    # Every callback returns the same 7-tuple, in this order.
    shared_outputs = [
        query_image,
        label_plot,
        current_index,
        history,
        data_gr,
        training_samples,
        labeled_images,
    ]

    # Passing the button itself as an input hands its label to update_app as
    # the `decision` argument.
    accept_btn.click(
        update_app,
        inputs=[accept_btn, data_gr, current_index, history, username],
        outputs=shared_outputs,
    )
    myabe_btn.click(
        update_app,
        inputs=[myabe_btn, data_gr, current_index, history, username],
        outputs=shared_outputs,
    )
    reject_btn.click(
        update_app,
        inputs=[reject_btn, data_gr, current_index, history, username],
        outputs=shared_outputs,
    )

    prepare_btn.click(
        preprocessing,
        inputs=[data_gr, current_index, history, username],
        outputs=shared_outputs,
    )

demo.launch(debug=False)