Spaces:
Sleeping
Sleeping
import json | |
import gradio as gr | |
import numpy as np | |
import time | |
import csv | |
import json | |
import os | |
import random | |
import string | |
import sys | |
import time | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
from huggingface_hub import ( | |
CommitScheduler, | |
HfApi, | |
InferenceClient, | |
login, | |
snapshot_download, | |
) | |
from PIL import Image | |
from utils import string_to_image | |
import matplotlib.backends.backend_agg as agg | |
import math | |
from pathlib import Path | |
# Seed NumPy's global RNG from the clock.
# NOTE(review): nothing visible in this file draws from np.random (sampling
# uses the stdlib `random` module), so this seed may be vestigial.
np.random.seed(int(time.time()))
# Allow arbitrarily large CSV fields.
csv.field_size_limit(sys.maxsize)

# Number of images each participant is asked to label in one session.
NUMBER_OF_IMAGES = 30
intro_screen = Image.open("./images/intro.jpg")


def _load_json(path):
    """Read and parse the JSON file at *path*, closing it afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


meta_top1 = _load_json("./data/top1-exp-materials/metadata.json")
meta_topK = _load_json("./data/topK-exp-materials/metadata.json")
all_data = {"top1": meta_top1, "topK": meta_topK}
# Tag every metadata entry with the dataset it belongs to ("top1"/"topK") so
# the origin stays recorded after entries are sampled and copied.
for _kind, _entries in all_data.items():
    for _entry in _entries.values():
        _entry["type"] = _kind
###############################################################################
# Hugging Face Hub configuration: JSON result files written into
# JSON_DATASET_DIR are pushed to a private dataset repo by a CommitScheduler.
session_token = os.environ.get("SessionToken")
# huggingface_hub.login() prompts for a token when none is given, which cannot
# work in a headless Space. Only log in when the secret is present; otherwise
# rely on huggingface_hub's default token resolution.
if session_token:
    login(token=session_token)
REPO_URL = "xairesearch2023-advnet/HumanStudyData"
JSON_DATASET_DIR = Path("responses")
###############################################################################
# Make sure the watched folder exists before the scheduler starts.
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
scheduler = CommitScheduler(
    repo_id=REPO_URL,
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="./data",
    every=1,  # minutes between background commits
    private=True,
)
def generate_data(type_of_nns):
    """Randomly sample NUMBER_OF_IMAGES datapoints from one dataset.

    Args:
        type_of_nns: Dataset key, "top1" or "topK". The UI label "topQ" is
            accepted as an alias for "topK".

    Returns:
        A list of NUMBER_OF_IMAGES dicts. Each is a shallow copy of a metadata
        entry with an added "image-path" key pointing at the sample's image.

    Raises:
        KeyError: if *type_of_nns* is not a key of ``all_data``.
        ValueError: (from ``random.sample``) if the dataset holds fewer than
            NUMBER_OF_IMAGES entries.
    """
    if type_of_nns == "topQ":
        type_of_nns = "topK"
    dataset = all_data[type_of_nns]
    sampled_keys = random.sample(list(dataset), NUMBER_OF_IMAGES)
    # Build copies so the shared module-level metadata is not mutated by
    # every session that samples from it.
    return [
        {**dataset[key], "image-path": f"./data/{type_of_nns}-exp-materials/{key}.jpeg"}
        for key in sampled_keys
    ]
# HTML question shown above each image. Placeholders are
# (predicted label, confidence percent, predicted label).
_QUESTION_TEMPLATE = (
    "<div style='font-size: 24px;'>Sam guessed the Input image is "
    "<span style='font-weight: bold;'>{}</span> "
    "with <span style='font-weight: bold;'>{}%</span> "
    "confidence.<br>Is this bird a <span style='font-weight: bold;'>{}</span>?"
    "</div>"
)


def load_sample(data, current_index):
    """Load the sample at *current_index* and build its question.

    Args:
        data: List of datapoint dicts (as produced by ``generate_data``).
            Each must have "image-path", "top1-label", "top1-score" and
            "Accept/Reject" keys.
        current_index: Index into *data*.

    Returns:
        A tuple ``(image, top1_label, confidence_percent, question_html,
        accept_reject)``. *confidence_percent* is "top1-score" * 100 rounded
        up to an int. *accept_reject* is the raw "Accept/Reject" value.
    """
    current_datapoint = data[current_index]
    image = Image.open(current_datapoint["image-path"])
    top_1 = current_datapoint["top1-label"]
    # round() before ceil() so float noise such as 0.07 * 100 ==
    # 7.000000000000001 is not pushed up to 8.
    rounded_up_score = math.ceil(round(current_datapoint["top1-score"] * 100, 2))
    question = _QUESTION_TEMPLATE.format(top_1, rounded_up_score, top_1)
    accept_reject = current_datapoint["Accept/Reject"]
    return image, top_1, rounded_up_score, question, accept_reject
def preprocessing(data, type_of_nns, current_index, history, username):
    """Start a new experiment session.

    Draws a fresh sample, makes the username unique with a random
    8-character suffix, and loads the first sample.

    Args:
        data: Previous session data; ignored and replaced by a new sample.
        type_of_nns: Dropdown value ("top1" or "topQ").
        current_index: Previous index; reset to 0.
        history: Previous answers. A new empty list is returned instead, so a
            restarted session never mixes answers from a different sample.
        username: Name typed by the user. Blank or whitespace-only names fall
            back to "username".

    Returns:
        ``(image, top1_label, confidence_percent, question_html, accept_reject,
        current_index, history, data, username)``, matching the outputs wired
        to the start button.
    """
    print("preprocessing")
    data = generate_data(type_of_nns)
    print("data generated")
    # Append a random suffix so repeated names map to distinct result files.
    random_text = "".join(
        random.choice(string.ascii_lowercase + string.digits) for _ in range(8)
    )
    username = (username or "").strip() or "username"
    username = f"{username}-{random_text}"
    current_index = 0
    history = []
    print("loading sample ....")
    qimage, top_1, top_1_score, question, accept_reject = load_sample(
        data, current_index
    )
    return (
        qimage,
        top_1,
        top_1_score,
        question,
        accept_reject,
        current_index,
        history,
        data,
        username,
    )
def _figure_to_image(fig):
    """Render a Matplotlib figure with Agg and return it as an RGBA PIL image."""
    canvas = agg.FigureCanvasAgg(fig)
    canvas.draw()
    # buffer_rgba() yields bytes in R, G, B, A order, matching PIL's "RGBA"
    # mode. Decoding tostring_argb() output as "RGBA" shifts every channel.
    return Image.frombytes("RGBA", canvas.get_width_height(), bytes(canvas.buffer_rgba()))


def _is_accept(label):
    """Normalise a ground-truth "Accept/Reject" value to a bool.

    NOTE(review): the metadata schema is not visible here, so both a JSON
    boolean and the strings "Accept"/"Reject" are accepted. Confirm against
    metadata.json.
    """
    if isinstance(label, str):
        return label.strip().lower() in ("accept", "true", "yes")
    return bool(label)


def _record_decision(datapoint, decision, username):
    """Return a copy of *datapoint* annotated with the user's answer."""
    record = datapoint.copy()
    record["user_decision"] = decision
    record["user_id"] = username
    record["is_user_correct"] = _is_accept(record["Accept/Reject"]) == (decision == "YES")
    return record


def _results_path(username):
    """Return the per-user results file path inside JSON_DATASET_DIR."""
    # Keep only a conservative character set so a crafted username (e.g.
    # containing "/" or "..") cannot write outside the responses directory.
    safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in username)
    return JSON_DATASET_DIR / f"results_{safe_name or 'username'}.json"


def update_app(decision, data, current_index, history, username):
    """Record a YES/NO answer and advance to the next sample.

    Args:
        decision: Value of the clicked button ("YES" or "NO").
        data: Sampled datapoints for this session.
        current_index: Index of the sample being answered. -1 means the
            experiment has not been started.
        history: Answered records for this session; appended to in place.
        username: Session user id.

    Returns:
        A 12-tuple matching the outputs wired to the YES/NO buttons: image,
        top-1 label, top-1 score, question HTML, accept/reject value, index,
        history, data, labeled-images count, YES-button update, NO-button
        update, final-results HTML. Returns None for any index outside
        [-1, NUMBER_OF_IMAGES - 1].
    """
    if current_index == -1:
        # Experiment not started: show a placeholder and keep buttons disabled.
        placeholder = _figure_to_image(
            string_to_image("Please Enter your username and load samples")
        )
        return (
            placeholder,
            "",
            "",
            "",
            "",
            current_index,
            history,
            data,
            0,
            gr.update(interactive=False),
            gr.update(interactive=False),
            "",
        )

    if current_index == NUMBER_OF_IMAGES - 1:
        # Last sample: record it, persist the session and show the accuracy.
        time_stamp = int(time.time())
        history.append(_record_decision(data[current_index], decision, username))
        final_decision_data = {
            "user_id": username,
            "time": time_stamp,
            "history": history,
        }
        # Hold the scheduler lock so a background commit never uploads a
        # half-written file.
        with scheduler.lock:
            with _results_path(username).open("w", encoding="utf-8") as f:
                json.dump(final_decision_data, f)
        thank_you = _figure_to_image(string_to_image("Thank you for your time!"))
        accuracy = np.mean([d["is_user_correct"] for d in history]) * 100
        accuracy = round(accuracy, 2)
        return (
            thank_you,
            "",
            "",
            "",
            "",
            current_index,
            history,
            data,
            current_index + 1,
            gr.update(interactive=False),
            gr.update(interactive=False),
            f"User Accuracy: {accuracy}",
        )

    if 0 <= current_index < NUMBER_OF_IMAGES - 1:
        history.append(_record_decision(data[current_index], decision, username))
        current_index += 1
        qimage, top_1, top_1_score, question, accept_reject = load_sample(
            data, current_index
        )
        return (
            qimage,
            top_1,
            top_1_score,
            question,
            accept_reject,
            current_index,
            history,
            data,
            current_index,
            gr.update(interactive=True),
            gr.update(interactive=True),
            "",
        )
    return None
def _interactivity(enabled):
    """Build a Gradio update that toggles whether a component accepts input."""
    return gr.update(interactive=enabled)


def disable_component():
    """Gradio callback: make the target component read-only."""
    return _interactivity(False)


def enable_component():
    """Gradio callback: make the target component accept input."""
    return _interactivity(True)


def hide_component():
    """Gradio callback: remove the target component from view."""
    hidden = gr.update(visible=False)
    return hidden
def _lock_setup_controls():
    """Freeze the setup controls and enable the answer buttons.

    Output order: prepare_btn, type_of_nns_dropdown, username_textbox,
    accept_btn, reject_btn.
    """
    return (
        gr.update(interactive=False, visible=False),
        gr.update(interactive=False),
        gr.update(interactive=False),
        gr.update(interactive=True),
        gr.update(interactive=True),
    )


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Per-session state: sampled datapoints (a list), index of the sample on
    # screen (-1 = not started) and the list of answered records.
    data_state = gr.State([])
    current_index = gr.State(-1)
    history = gr.State([])
    gr.Markdown("# Advising Networks")
    gr.Markdown("## Accept/Reject AI predicted label using Explanations")
    with gr.Column():
        with gr.Row():
            username_textbox = gr.Textbox(label="Username", value="username")
            labeled_images_textbox = gr.Textbox(label="Labeled Images", value="0")
            total_images_textbox = gr.Textbox(
                label="Total Images", value=NUMBER_OF_IMAGES
            )
            type_of_nns_dropdown = gr.Dropdown(
                label="Type of NNs",
                choices=["top1", "topQ"],
                value="top1",
            )
        prepare_btn = gr.Button(value="Start The Experiment")
    with gr.Column():
        with gr.Row():
            question_textbox = gr.HTML("")
        with gr.Column(elem_id="parent_row"):
            query_image = gr.Image(
                type="pil", label="Query", show_label=False, value="./images/intro.jpg"
            )
            with gr.Row():
                accept_btn = gr.Button(value="YES", interactive=False)
                reject_btn = gr.Button(value="NO", interactive=False)
        # Hidden fields carrying the current sample's values.
        with gr.Column(elem_id="parent_row"):
            top_1_textbox = gr.Textbox(label="Top 1", value="", visible=False)
            top_1_score_textbox = gr.Textbox(
                label="Top 1 Score", value="", visible=False
            )
            accept_reject_textbox = gr.Textbox(
                label="Accept/Reject", value="", visible=False
            )
    with gr.Column():
        with gr.Row():
            final_results = gr.HTML("")

    # Start: sample data, then lock the setup controls in a single update.
    prepare_btn.click(
        preprocessing,
        inputs=[
            data_state,
            type_of_nns_dropdown,
            current_index,
            history,
            username_textbox,
        ],
        outputs=[
            query_image,
            top_1_textbox,
            top_1_score_textbox,
            question_textbox,
            accept_reject_textbox,
            current_index,
            history,
            data_state,
            username_textbox,
        ],
    ).then(
        fn=_lock_setup_controls,
        outputs=[
            prepare_btn,
            type_of_nns_dropdown,
            username_textbox,
            accept_btn,
            reject_btn,
        ],
    )

    # YES and NO share one handler; the clicked button's value is the decision.
    answer_outputs = [
        query_image,
        top_1_textbox,
        top_1_score_textbox,
        question_textbox,
        accept_reject_textbox,
        current_index,
        history,
        data_state,
        labeled_images_textbox,
        accept_btn,
        reject_btn,
        final_results,
    ]
    for answer_btn in (accept_btn, reject_btn):
        answer_btn.click(
            update_app,
            inputs=[answer_btn, data_state, current_index, history, username_textbox],
            outputs=answer_outputs,
        )

demo.launch(debug=False, server_name="0.0.0.0")