import csv
import json
import os
import pickle
import random
import string
import sys
import time
from glob import glob

import datasets
import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torchvision
from huggingface_hub import HfApi, snapshot_download
from PIL import Image

# Allow very large CSV fields (e.g., when loading dataset metadata).
csv.field_size_limit(sys.maxsize)
# Re-seed NumPy so each session samples a different subset of images.
np.random.seed(int(time.time()))

# Precomputed nearest-neighbor indices for every ImageNet-Hard query image.
with open("./imagenet_hard_nearest_indices.pkl", "rb") as f:
    knn_results = pickle.load(f)

# WordNet-ID -> human-readable label mapping.
with open("imagenet-labels.json") as f:
    wnid_to_label = json.load(f)

# Class-id -> label mapping.
with open("id_to_label.json", "r") as f:
    id_to_labels = json.load(f)

# ex2.txt lists the flagged ("bad") sample ids, one "<id>.<ext>" entry per line;
# strip the extension and keep the integer ids.
with open("./ex2.txt", "r") as f:
    bad_items = f.read().split("\n")
bad_items = [x.split(".")[0] for x in bad_items]
bad_items = [int(x) for x in bad_items if x != ""]
# Download and extract the support folders (KNN grids and training samples).
gdown.cached_download(
    url="https://huggingface.co/datasets/taesiri/imagenet_hard_review_samples/resolve/main/data.zip",
    path="./data.zip",
    quiet=False,
    md5="8666a9b361f6eea79878be6c09701def",
)

# Extract only if the target folders are missing.
if not os.path.exists("./imagenet_traning_samples") or not os.path.exists("./knn_cache_for_imagenet_hard"):
    torchvision.datasets.utils.extract_archive(
        from_path="data.zip",
        to_path="./",
        remove_finished=False,
    )
imagenet_hard = datasets.load_dataset("taesiri/imagenet-hard", split="validation")
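
# Reviewer decisions live in the `taesiri/imagenet_hard_review_data` dataset repo
# as one small JSON file per decision (written by `update_app` below).
# `update_snapshot` downloads them all and flattens them into a DataFrame so
# sampling can be weighted toward the least-reviewed images.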
def update_snapshot():
    output_dir = snapshot_download(
        repo_id="taesiri/imagenet_hard_review_data",
        allow_patterns="*.json",
        repo_type="dataset",
    )
    total_size = len(imagenet_hard)
    files = glob(f"{output_dir}/*.json")

    columns = ["id", "user_id", "time", "decision"]
    rows = []
    for file in files:
        with open(file) as f:
            data = json.load(f)
        rows.append([data[x] for x in columns])
    df = pd.DataFrame(rows, columns=columns)
    return df, total_size
NUMBER_OF_IMAGES = 1000
# Sample `sample_size` ids, weighted toward the ids with the fewest recorded reviews.
def sample_ids(df, total_size, sample_size):
    # Count how many decisions each id already has; unseen ids count as zero.
    id_counts = df["id"].value_counts().to_dict()
    all_ids = bad_items
    for item_id in all_ids:
        if item_id not in id_counts:
            id_counts[item_id] = 0

    # Inverse-frequency weights: an id reviewed n times gets weight 1 / (n + 1).
    inverse_weights = [1 / (id_counts[item_id] + 1) for item_id in all_ids]
    total = sum(inverse_weights)
    normalized_weights = [w / total for w in inverse_weights]

    sampled_ids = np.random.choice(all_ids, size=sample_size, replace=False, p=normalized_weights)
    return sampled_ids
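
# Example: with review counts {0: 2, 1: 0} the weights become 1/3 and 1/1, so
# the unreviewed image is three times as likely to be drawn as the twice-reviewed one.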
def generate_dataset():
    df, total_size = update_snapshot()
    random_indices = sample_ids(df, total_size, NUMBER_OF_IMAGES)
    random_images = [imagenet_hard[int(i)]["image"] for i in random_indices]
    random_gt_labels = [imagenet_hard[int(x)]["english_label"] for x in random_indices]

    data = []
    for i, image in enumerate(random_images):
        data.append(
            {
                "id": random_indices[i],
                "image": image,
                "correct_label": random_gt_labels[i],
                "original_id": int(random_indices[i]),
            }
        )
    return data
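
# `string_to_image` renders the candidate label as a matplotlib figure so it
# can be shown in a `gr.Plot` next to the query image.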
def string_to_image(text):
    text = text.replace("_", " ").lower().replace(", ", "\n")
    # Create a blank white image
    img = np.ones((220, 75, 3))
    # Create a figure and axis object
    fig, ax = plt.subplots(figsize=(6, 2.25))
    # Plot the blank white image
    ax.imshow(img, extent=[0, 1, 0, 1])
    # Draw the label text near the center
    ax.text(0.5, 0.75, text, fontsize=18, ha="center", va="center")
    # Remove the axis labels and ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    # Remove the axis spines
    for spine in ax.spines.values():
        spine.set_visible(False)
    return fig
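
# Label distribution over the query's 15 nearest training-set neighbors
# (precomputed features; the UI below attributes them to a ResNet-50).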
def label_dist_of_nns(qid):
    with open("./trainingset_filenames.json", "r") as f:
        trainingset_filenames = json.load(f)

    nns = knn_results[qid][:15]
    labels = [wnid_to_label[trainingset_filenames[f"{x}"]] for x in nns]
    label_counts = {x: labels.count(x) for x in set(labels)}
    # Sort by count, most frequent label first
    label_counts = {k: v for k, v in sorted(label_counts.items(), key=lambda item: item[1], reverse=True)}
    # Convert counts to fractions of the neighborhood
    label_counts = {k: v / len(labels) for k, v in label_counts.items()}
    return label_counts
# Map each training-sample file to its class id; filenames look like "<id>_*.JPEG".
# (The "traning" spelling matches the folder name inside data.zip.)
all_samples = glob("./imagenet_traning_samples/*.JPEG")
qid_to_sample = {int(x.split("/")[-1].split(".")[0].split("_")[0]): x for x in all_samples}
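
# For a query image, fetch one stored training image per ground-truth label id.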
def get_training_samples(qid):
    label_ids = imagenet_hard[int(qid)]["label"]
    samples = [qid_to_sample[x] for x in label_ids]
    return samples
knn_cache_path = "knn_cache_for_imagenet_hard"
imagenet_training_samples_path = "imagenet_traning_samples"
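
# Load the current sample: the query image, the pre-rendered grid of its
# nearest neighbors from the KNN cache, and its candidate label(s).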
def load_sample(data, current_index):
    image_id = data[current_index]["id"]
    qimage = data[current_index]["image"]

    neighbors_path = os.path.join(knn_cache_path, f"{image_id}.JPEG")
    neighbors_image = Image.open(neighbors_path).convert("RGB")

    labels = data[current_index]["correct_label"]
    return qimage, neighbors_image, labels
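
# Main event handler, wired to all three decision buttons. The clicked button's
# label string arrives as `decision`. State machine over `current_index`:
#   -1                      -> first click: build a fresh dataset, record nothing
#   0 .. NUMBER_OF_IMAGES-2 -> record the decision and upload it to the Hub
#   NUMBER_OF_IMAGES-1      -> session finished: clear all outputs
# Each decision is uploaded as its own timestamped JSON file, which is what
# `update_snapshot` later aggregates.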
def update_app(decision, data, current_index, history, username):
    if current_index == -1:
        data = generate_dataset()

    if current_index >= 0 and current_index < NUMBER_OF_IMAGES - 1:
        time_stamp = int(time.time())
        image_id = data[current_index]["id"]

        decision_dict = {
            "id": int(image_id),
            "user_id": username,
            "time": time_stamp,
            "decision": decision,
        }

        # Save the decision to disk, upload it to the server, then clean up.
        temp_filename = f"results_{username}_{time_stamp}.json"
        with open(temp_filename, "w") as f:
            json.dump(decision_dict, f)
        api = HfApi()
        api.upload_file(
            path_or_fileobj=temp_filename,
            path_in_repo=temp_filename,
            repo_id="taesiri/imagenet_hard_review_data",
            repo_type="dataset",
        )
        os.remove(temp_filename)
    elif current_index == NUMBER_OF_IMAGES - 1:
        return None, None, None, current_index, history, data, None, None

    current_index += 1
    qimage, neighbors_image, labels = load_sample(data, current_index)
    image_id = data[current_index]["id"]

    training_samples_image = get_training_samples(image_id)
    training_samples_image = [Image.open(x).convert("RGB") for x in training_samples_image]
    nns = label_dist_of_nns(image_id)

    # `labels` is a list of labels; join it into one string and render it as a figure
    labels = ", ".join(labels)
    label_plot = string_to_image(labels)

    return qimage, label_plot, neighbors_image, current_index, history, data, nns, training_samples_image
newcss = '''
#query_image{
    height: auto !important;
}

#nn_gallery {
    height: auto !important;
}

#sample_gallery {
    height: auto !important;
}
'''
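
# Gradio UI: a username box, three decision buttons, the query image with its
# rendered label and training samples, and the nearest-neighbor analysis.
# Each button is also passed as an *input* to `update_app`, so its label
# ("Accept" / "Not Sure!" / "Reject") is what lands in the `decision` argument.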
with gr.Blocks(css=newcss) as demo:
    data_gr = gr.State({})
    current_index = gr.State(-1)
    history = gr.State({})

    gr.Markdown("# Cleaning ImageNet-Hard!")

    # Give every visitor a random default username so uploaded files don't collide.
    random_str = "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(5)
    )
    username = gr.Textbox(label="Username", value=f"user-{random_str}")

    with gr.Column():
        with gr.Row():
            accept_btn = gr.Button(value="Accept")
            maybe_btn = gr.Button(value="Not Sure!")
            reject_btn = gr.Button(value="Reject")
        with gr.Row():
            query_image = gr.Image(type="pil", label="Query", elem_id="query_image")
            with gr.Column():
                label_plot = gr.Plot(label="Is this a correct label for this image?")
                training_samples = gr.Gallery(label="Training samples", elem_id="sample_gallery")
        with gr.Column():
            gr.Markdown("## Nearest Neighbors Analysis of the Query (ResNet-50)")
            nn_labels = gr.Label(label="NN-Labels")
            neighbors_image = gr.Image(type="pil", label="Nearest Neighbors", elem_id="nn_gallery")

    # All three buttons drive the same handler and update the same components.
    shared_outputs = [
        query_image,
        label_plot,
        neighbors_image,
        current_index,
        history,
        data_gr,
        nn_labels,
        training_samples,
    ]
    accept_btn.click(
        update_app,
        inputs=[accept_btn, data_gr, current_index, history, username],
        outputs=shared_outputs,
    )
    maybe_btn.click(
        update_app,
        inputs=[maybe_btn, data_gr, current_index, history, username],
        outputs=shared_outputs,
    )
    reject_btn.click(
        update_app,
        inputs=[reject_btn, data_gr, current_index, history, username],
        outputs=shared_outputs,
    )

demo.launch()