# Spaces:
# Paused
# Paused
import csv
import json
import os
import pickle
import random
import re
import string
import sys
import time
from fnmatch import translate
from glob import escape as glob_escape
from glob import glob

import datasets
import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from huggingface_hub import HfApi, login, snapshot_download
from PIL import Image
# Authenticate with the Hugging Face Hub using the token stored in the
# "SessionToken" environment variable (None when the variable is unset).
session_token = os.environ.get("SessionToken")
login(token=session_token)
# Raise the csv module's per-field size limit to the maximum.
csv.field_size_limit(sys.maxsize)
np.random.seed(int(time.time()))

# NOTE(review): pickle.load can execute arbitrary code; this local file must be trusted.
with open("./imagenet_hard_nearest_indices.pkl", "rb") as f:
    knn_results = pickle.load(f)
with open("imagenet-labels.json") as f:
    wnid_to_label = json.load(f)
with open("id_to_label.json", "r") as f:
    id_to_labels = json.load(f)

imagenet_training_samples_path = "imagenet_samples"

# Ids of the images to review. Each non-empty line of ex2.txt is expected to
# start with an integer id, optionally followed by ".<suffix>"; only the part
# before the first "." is kept. Read inside ``with`` so the handle is closed.
with open("./ex2.txt", "r") as f:
    bad_items = f.read().split("\n")
bad_items = [x.split(".")[0] for x in bad_items]
bad_items = [int(x) for x in bad_items if x != ""]
# Total number of images in the review set.
NUMBER_OF_IMAGES = len(bad_items)

# Archive expected to contain the ./imagenet_samples and
# ./knn_cache_for_imagenet_hard folders checked below.
# NOTE(review): with ``md5`` given, gdown.cached_download presumably reuses an
# existing ./data.zip whose checksum matches instead of downloading again —
# confirm against the installed gdown version.
gdown.cached_download(
    url="https://huggingface.co/datasets/taesiri/imagenet_hard_review_samples/resolve/main/data.zip",
    path="./data.zip",
    quiet=False,
    md5="ece2720fed664e71799f316a881d4324",
)
# Extract only when either expected output folder is missing; keep the zip.
if not os.path.exists("./imagenet_samples") or not os.path.exists(
    "./knn_cache_for_imagenet_hard"
):
    torchvision.datasets.utils.extract_archive(
        from_path="data.zip",
        to_path="./",
        remove_finished=False,
    )

imagenet_hard = datasets.load_dataset("taesiri/imagenet-hard", split="validation")
def update_snapshot(username):
    """Collect the review decisions recorded so far for *username*.

    Decisions are read from two Hugging Face dataset repos:

    * ``taesiri/imagenet_hard_review_data_r2``: one JSON file per decision.
      Only files whose name contains *username* are requested.
    * ``taesiri/imagenet_hard_review_data_r3``: every CSV file is downloaded
      and appended.

    Args:
        username: Reviewer name to filter on.

    Returns:
        A DataFrame with (at least) the columns ``id``, ``user_id``, ``time``,
        ``decision``. It holds only rows whose ``user_id`` equals *username*,
        de-duplicated on (``id``, ``user_id``), keeping the last occurrence.
        JSON rows come before CSV rows, each in glob order.
    """
    # snapshot_download filters repo files with fnmatch-style *glob* patterns,
    # so build a glob here, not a regex (fnmatch.translate output never
    # matches). Glob metacharacters in the username are escaped.
    pattern = f"*{glob_escape(username)}*.json"
    output_dir = snapshot_download(
        repo_id="taesiri/imagenet_hard_review_data_r2",
        allow_patterns=pattern,
        repo_type="dataset",
    )
    columns = ["id", "user_id", "time", "decision"]
    rows = []
    # The local snapshot folder may also contain files cached by earlier calls
    # for other users; the user_id filter at the end removes them.
    for file in glob(f"{output_dir}/*.json"):
        with open(file) as f:
            record = json.load(f)
        rows.append([record[column] for column in columns])
    df = pd.DataFrame(rows, columns=columns)

    # Download and append all CSV files.
    output_dir = snapshot_download(
        repo_id="taesiri/imagenet_hard_review_data_r3",
        allow_patterns="*.csv",
        repo_type="dataset",
    )
    csv_files = glob(f"{output_dir}/*.csv")
    if csv_files:
        csv_df = pd.concat(
            [pd.read_csv(file) for file in csv_files], ignore_index=True
        )
        df = pd.concat([df, csv_df], ignore_index=True)

    # Remove duplicate decisions for the same (image, user) pair.
    df = df.drop_duplicates(subset=["id", "user_id"], keep="last")
    return df[df["user_id"] == username]
def generate_dataset(username):
    """Build the shuffled review queue of images *username* has not reviewed yet.

    Args:
        username: Reviewer name. Decisions already recorded for it are fetched
            with ``update_snapshot``.

    Returns:
        A list of ``{"id": image_id}`` dicts, one per id in ``bad_items`` that
        is not among the user's recorded decisions, in random order. The list
        is empty when nothing is left or ``bad_items`` is empty.
    """
    global NUMBER_OF_IMAGES
    df = update_snapshot(username)
    answered = set(df["id"])
    remaining = list(set(bad_items) - answered)
    random.shuffle(remaining)

    # NOTE: NUMBER_OF_IMAGES is the size of the whole review set, not of the
    # returned (remaining) queue.
    NUMBER_OF_IMAGES = len(bad_items)
    print(f"NUMBER_OF_IMAGES: {NUMBER_OF_IMAGES}")
    print(f"Remaining: {len(remaining)}")
    return [{"id": image_id} for image_id in remaining]
def string_to_image(text):
    """Render *text* as a matplotlib figure for display in a ``gr.Plot``.

    Underscores become spaces, the text is lower-cased, and each ", "
    separator starts a new line.

    Args:
        text: Text to draw (for example, comma-joined labels or a status message).

    Returns:
        A ``matplotlib.figure.Figure``. It has already been released from
        pyplot's figure registry, but the Figure object can still be saved or
        rendered by the caller.
    """
    text = text.replace("_", " ").lower().replace(", ", "\n")
    # Blank white background image, stretched over the unit square by ``extent``.
    img = np.ones((220, 75, 3))
    fig, ax = plt.subplots(figsize=(6, 2.25))
    ax.imshow(img, extent=[0, 1, 0, 1])
    ax.text(0.5, 0.75, text, fontsize=18, ha="center", va="center")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    # pyplot keeps every figure made by plt.subplots alive until it is closed.
    # This function runs on every button click, so release the figure here to
    # avoid accumulating open figures. The returned Figure stays usable.
    plt.close(fig)
    return fig
def sample_key_from_path(path):
    """Return the integer prefix of a sample file name.

    The prefix is the text before the first "." and then before the first "_",
    e.g. ``"./imagenet_samples/123_x.JPEG"`` -> ``123``. ``os.path.basename``
    is used instead of splitting on "/" so native path separators work.
    """
    return int(os.path.basename(path).split(".")[0].split("_")[0])


all_samples = glob("./imagenet_samples/*.JPEG")
# Maps a sample file's integer prefix to its path. get_training_samples looks
# these keys up with the ids from a dataset row's "label" field, so the keys
# are presumably class-label ids despite the "qid" name — TODO confirm.
qid_to_sample = {sample_key_from_path(x): x for x in all_samples}
def get_training_samples(qid):
    """Return the training-sample image paths for every label of dataset row *qid*.

    Each entry of the row's ``"label"`` field is looked up in ``qid_to_sample``.
    A label with no sample file raises ``KeyError``.
    """
    row_labels = imagenet_hard[int(qid)]["label"]
    return list(map(qid_to_sample.__getitem__, row_labels))
def load_sample(data, current_index):
    """Look up the query image and its English labels for one queue entry.

    Args:
        data: Review queue (list of ``{"id": ...}`` dicts).
        current_index: Position in *data*.

    Returns:
        ``(image, labels)``, taken from the ``"image"`` and ``"english_label"``
        fields of the matching ``imagenet_hard`` row.
    """
    row = imagenet_hard[int(data[current_index]["id"])]
    return row["image"], row["english_label"]
def preprocessing(data, current_index, history, username):
    """Load a fresh review queue for *username* and display its first sample.

    Wired to the "Load Samples" button. The incoming *data* state is ignored;
    the queue is always rebuilt with ``generate_dataset``.

    Args:
        data: Previous queue state (unused; replaced).
        current_index: Previous position in the queue (replaced).
        history: Passed through unchanged.
        username: Reviewer name used to build the queue.

    Returns:
        A 7-tuple matching the button's outputs: (query image, label plot,
        current index, history, data, training-sample images, labeled count).
    """
    data = generate_dataset(username)
    remaining_images = len(data)
    labeled_images = len(bad_items) - remaining_images

    if remaining_images == 0:
        fake_plot = string_to_image("No more images to review")
        empty_image = Image.new("RGB", (224, 224))
        # Reset to -1, the "nothing loaded" value update_app checks for. This
        # stops a stale index from an earlier queue being used to index the
        # now-empty one.
        return (
            empty_image,
            fake_plot,
            -1,
            history,
            data,
            None,
            labeled_images,
        )

    current_index = 0
    qimage, labels = load_sample(data, current_index)
    image_id = data[current_index]["id"]
    training_samples_image = [
        Image.open(path).convert("RGB") for path in get_training_samples(image_id)
    ]
    # labels is a list of label strings; render them as one comma-separated text.
    label_plot = string_to_image(", ".join(labels))
    return (
        qimage,
        label_plot,
        current_index,
        history,
        data,
        training_samples_image,
        labeled_images,
    )
def _upload_decision(image_id, username, decision):
    """Record one review decision as a JSON file in the r2 review-data dataset repo.

    The JSON is written to the working directory, uploaded with
    ``HfApi.upload_file`` under the same name, and then deleted locally.
    """
    time_stamp = int(time.time())
    decision_dict = {
        "id": int(image_id),
        "user_id": username,
        "time": time_stamp,
        "decision": decision,
    }
    # Include the image id so that two decisions recorded within the same
    # second get different names and do not overwrite each other in the repo.
    # The name still contains the username, which the ``*<username>*.json``
    # download pattern in update_snapshot matches on.
    temp_filename = f"results_{username}_{time_stamp}_{int(image_id)}.json"
    with open(temp_filename, "w") as f:
        json.dump(decision_dict, f)
    try:
        HfApi().upload_file(
            path_or_fileobj=temp_filename,
            path_in_repo=temp_filename,
            repo_id="taesiri/imagenet_hard_review_data_r2",
            repo_type="dataset",
        )
    finally:
        # Remove the local copy even if the upload raises.
        os.remove(temp_filename)


def _message_view(message):
    """Return (blank 224x224 image, plot showing *message*) for status screens."""
    return Image.new("RGB", (224, 224)), string_to_image(message)


def _sample_view(data, index):
    """Return (query image, label plot, training-sample images) for ``data[index]``."""
    qimage, labels = load_sample(data, index)
    training_samples_image = [
        Image.open(path).convert("RGB")
        for path in get_training_samples(data[index]["id"])
    ]
    # labels is a list of label strings; render them as one comma-separated text.
    label_plot = string_to_image(", ".join(labels))
    return qimage, label_plot, training_samples_image


def update_app(decision, data, current_index, history, username):
    """Record *decision* for the current sample and advance to the next one.

    Wired to the Accept / Not Sure! / Reject buttons; *decision* is the value
    of the clicked button.

    Args:
        decision: Value stored as the decision.
        data: Review queue built by ``generate_dataset`` (list of ``{"id": ...}``).
        current_index: Position in *data*; -1 means nothing has been loaded.
        history: Passed through unchanged.
        username: Reviewer name stored with the decision.

    Returns:
        A 7-tuple matching the button's outputs: (query image, label plot,
        current index, history, data, training-sample images, labeled count).
        After the last sample the returned index equals ``len(data)``, so
        further clicks do not upload anything again.
    """
    if current_index < 0:
        empty_image, fake_plot = _message_view(
            "Please Enter your username and load samples"
        )
        return empty_image, fake_plot, current_index, history, data, None, 0

    # Use this session's queue length rather than the global NUMBER_OF_IMAGES.
    # The queue holds only the images this user has not reviewed yet, so the
    # global total overruns it for a returning user (IndexError). The global
    # is also shared by every session.
    num_items = len(data)
    # Images the user had already reviewed before this queue was loaded.
    previously_labeled = len(bad_items) - num_items

    if current_index >= num_items:
        # Queue already finished (or empty): do not re-upload anything.
        empty_image, fake_plot = _message_view("Thank you for your time!")
        return (
            empty_image,
            fake_plot,
            current_index,
            history,
            data,
            None,
            previously_labeled + num_items,
        )

    _upload_decision(data[current_index]["id"], username, decision)
    current_index += 1

    if current_index == num_items:
        # That was the last sample in the queue.
        empty_image, fake_plot = _message_view("Thank you for your time!")
        return (
            empty_image,
            fake_plot,
            current_index,
            history,
            data,
            None,
            previously_labeled + current_index,
        )

    qimage, label_plot, training_samples_image = _sample_view(data, current_index)
    return (
        qimage,
        label_plot,
        current_index,
        history,
        data,
        training_samples_image,
        previously_labeled + current_index,
    )
# Custom CSS passed to gr.Blocks(css=...). The query image and the right-hand
# column (label plot + training samples) are laid out side by side in a flex
# row. The ids/classes query_image, sample_gallery, svelte-parentrowclass and
# svelte-rightcolumn are set on components in this file; #nn_gallery does not
# match any component visible here.
# `!important` must come before the semicolon; the previous
# "flex-grow: 0; !important;" was invalid, and the flag was dropped by the parser.
newcss = """
#query_image{
}
#nn_gallery {
height: auto !important;
}
#sample_gallery {
height: auto !important;
}
/* Set display to flex for the parent element */
.svelte-parentrowclass {
display: flex;
}
/* Set the flex-grow property for the children elements */
.svelte-parentrowclass > #query_image {
min-width: min(400px, 40%);
flex : 1;
flex-grow: 0 !important;
border-style: solid;
height: auto !important;
}
.svelte-parentrowclass > .svelte-rightcolumn {
flex: 2;
flex-grow: 0 !important;
min-width: min(600px, 60%);
}
"""
with gr.Blocks(css=newcss, theme=gr.themes.Soft()) as demo:
    # Per-session state shared by the callbacks: the review queue, the current
    # position in it (-1 until samples are loaded) and a pass-through history.
    data_gr = gr.State({})
    current_index = gr.State(-1)
    history = gr.State({})

    gr.Markdown("# Help Us to Clean `ImageNet-Hard`!")
    gr.Markdown("## Instructions")
    gr.Markdown(
        "Please enter your username and press `Load Samples`. The loading process might take up to a minute. Once the loading is done, you can start reviewing the samples."
    )
    gr.Markdown(
        """For each image, please select one of the following options: `Accept`, `Not Sure!`, `Reject`.
- If you think any of the labels are correct, please select `Accept`.
- If you think none of the labels matching the image, please select `Reject`.
- If you are not sure about the label, please select `Not Sure!`.
You can refer to `Training samples` if you are not sure about the target label.
"""
    )

    # Default username suffix: five random lowercase letters/digits.
    alphabet = string.ascii_lowercase + string.digits
    random_str = "".join(random.choice(alphabet) for _ in range(5))

    with gr.Column():
        with gr.Row():
            username = gr.Textbox(label="Username", value=f"user-{random_str}")
            labeled_images = gr.Textbox(label="Labeled Images", value="0")
            total_images = gr.Textbox(label="Total Images", value=len(bad_items))
        prepare_btn = gr.Button(value="Load Samples")

    with gr.Column():
        with gr.Row():
            accept_btn = gr.Button(value="Accept")
            maybe_btn = gr.Button(value="Not Sure!")
            reject_btn = gr.Button(value="Reject")
        with gr.Row(elem_id="parent_row", elem_classes="svelte-parentrowclass"):
            query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
            with gr.Column(
                elem_id="samples_col",
                elem_classes="svelte-rightcolumn",
            ):
                label_plot = gr.Plot(
                    label="Is this a correct label for this image?", type="fig"
                )
                training_samples = gr.Gallery(
                    type="pil", label="Training samples", elem_id="sample_gallery"
                )

    # Every callback updates the same seven components, in this order.
    review_outputs = [
        query_image,
        label_plot,
        current_index,
        history,
        data_gr,
        training_samples,
        labeled_images,
    ]

    # Each review button passes its own value (its label) as the decision.
    for decision_btn in (accept_btn, maybe_btn, reject_btn):
        decision_btn.click(
            update_app,
            inputs=[decision_btn, data_gr, current_index, history, username],
            outputs=list(review_outputs),
        )

    prepare_btn.click(
        preprocessing,
        inputs=[data_gr, current_index, history, username],
        outputs=list(review_outputs),
    )

demo.launch(debug=False)