# Hugging Face Space status at time of capture: Runtime error
# Basic example for doing model-in-the-loop dynamic adversarial data collection
# using Gradio Blocks.
import json
import os
import random
from pathlib import Path
from urllib.parse import parse_qs

import gradio as gr
import requests
from dotenv import load_dotenv
from huggingface_hub import Repository
from transformers import pipeline
# Settings for persisting the collected mturk HITs in a Hugging Face dataset.
# They are read from the environment, optionally populated from a local .env.
_ENV_FILE = ".env"
if Path(_ENV_FILE).is_file():
    load_dotenv(_ENV_FILE)

HF_TOKEN = os.getenv("HF_TOKEN")
DATASET_REPO_URL = os.getenv("DATASET_REPO_URL")

# Local checkout directory of the dataset repo; finished HITs are appended to
# DATA_FILE inside it and pushed back by _store_in_huggingface_dataset.
_DATA_DIR = "data"
DATA_FILENAME = "data.jsonl"
DATA_FILE = os.path.join(_DATA_DIR, DATA_FILENAME)

repo = Repository(
    local_dir=_DATA_DIR,
    clone_from=DATASET_REPO_URL,
    use_auth_token=HF_TOKEN,
)
# Now let's run the app!
# The model in the loop. It is pinned explicitly (this is the sentiment model
# the pipeline defaults to) so the collected adversarial examples are always
# tied to the same known model, not to whatever default the installed
# transformers version happens to pick.
pipe = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
demo = gr.Blocks()
with demo:
    total_cnt = 2  # How many examples per HIT
    # Hidden textbox that receives window.location.search via
    # get_window_location_search_js (carries mturk's assignmentId, if any).
    dummy = gr.Textbox(visible=False)
    # Progress is tracked in a hidden JSON component that every handler takes
    # as input and returns as output.
    state_dict = {"assignmentId": "", "cnt": 0, "cnt_fooled": 0, "data": []}
    state = gr.JSON(state_dict, visible=False)

    def _progress_md(state):
        """Return the progress line shown under the title for *state*."""
        return f"State: {state['cnt']}/{total_cnt} ({state['cnt_fooled']} fooled)"

    gr.Markdown("# DADC in Gradio example")
    gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")

    state_display = gr.Markdown(_progress_md(state_dict))

    # Generate model prediction
    # Default model: distilbert-base-uncased-finetuned-sst-2-english
    def _predict(txt, tgt, state, dummy):
        """Score one user-written example against the model and update progress.

        Args:
            txt: Statement entered by the user.
            tgt: Target label picked in the radio ("Positive" or "Negative");
                None when nothing has been selected.
            state: Hidden JSON state dict with keys assignmentId, cnt,
                cnt_fooled and data.
            dummy: The page's window.location.search string.

        Returns:
            A 7-tuple matching the outputs of submit_ex_button.click: label
            confidences, result markdown, updated state, visibility updates
            for the example/HIT submit columns, progress markdown and the
            query string (passed through unchanged).
        """
        # Record mturk's assignmentId as early as possible so it is already in
        # the state when the HIT is submitted; the turker needs it for credit.
        query = parse_qs((dummy or "").lstrip("?"))
        if "assignmentId" in query:
            state["assignmentId"] = query["assignmentId"][0]

        if not tgt or not (txt or "").strip():
            # Without a statement and a target there is nothing to compare the
            # prediction with; counting it would record a bogus "fooled" example.
            done = state["cnt"] >= total_cnt
            return (
                None,
                "Please enter a statement and select the target (correct) label.",
                state,
                gr.update(visible=not done),
                gr.update(visible=done),
                _progress_md(state),
                dummy,
            )

        pred = pipe(txt)[0]
        pred_label = pred["label"].lower()
        other_label = "negative" if pred_label == "positive" else "positive"
        # The binary model only reports the top label's score; the other label
        # gets the complement so gr.Label can show both.
        pred_confidences = {pred_label: pred["score"], other_label: 1 - pred["score"]}
        # Title-case so it compares equal to the radio choices ("Positive"/"Negative").
        pred["label"] = pred["label"].title()

        ret = f"Target: **{tgt}**. Model prediction: **{pred['label']}**\n\n"
        fooled = pred["label"] != tgt
        if fooled:
            state["cnt_fooled"] += 1
            ret += " You fooled the model! Well done!"
        else:
            ret += " You did not fool the model! Too bad, try again!"
        state["data"].append({"cnt": state["cnt"], "text": txt, "target": tgt, "model_pred": pred["label"], "fooled": fooled})
        state["cnt"] += 1

        # Once the HIT quota is reached, swap the example button for the HIT one.
        done = state["cnt"] >= total_cnt
        return (
            pred_confidences,
            ret,
            state,
            gr.update(visible=not done),
            gr.update(visible=done),
            _progress_md(state),
            dummy,
        )

    # Input fields
    text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
    # NOTE: the shuffle runs once, when this UI is built, not on every page load.
    labels = ["Positive", "Negative"]
    random.shuffle(labels)
    label_input = gr.Radio(choices=labels, label="Target (correct) label")
    label_output = gr.Label()
    text_output = gr.Markdown()
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")

    # Store the HIT data into a Hugging Face dataset.
    # The HIT is also stored and logged on mturk when post_hit_js is run below.
    # This _store_in_huggingface_dataset function just demonstrates how easy it is
    # to automatically create a Hugging Face dataset from mturk.
    def _store_in_huggingface_dataset(state):
        """Append the HIT's examples to DATA_FILE (one JSON line each) and push.

        Every line is an example dict tagged with the HIT's assignmentId,
        which is "" when the app is not used through mturk.

        Returns:
            The state, unchanged, so it can serve as the event output.
        """
        if not state or not state.get("data"):
            # Nothing collected (or no state received): writing would only add
            # a blank line to the JSONL file, and there is nothing to push.
            return state
        assignment_id = state.get("assignmentId", "")
        json_lines = [json.dumps({"assignmentId": assignment_id, **datum}) for datum in state["data"]]
        with open(DATA_FILE, "a", encoding="utf-8") as jsonlfile:
            jsonlfile.write("\n".join(json_lines) + "\n")
        # Push only after the file is closed, so the appended lines are flushed.
        repo.push_to_hub()
        return state

    # Button event handlers
    # Runs in the browser first; its return values replace the handler inputs,
    # which is how the page's query string reaches _predict as `dummy`.
    get_window_location_search_js = """
    function(text_input, label_input, state, dummy) {
        return [text_input, label_input, state, window.location.search];
    }
    """
    submit_ex_button.click(
        _predict,
        inputs=[text_input, label_input, state, dummy],
        outputs=[label_output, text_output, state, example_submit, final_submit, state_display, dummy],
        _js=get_window_location_search_js,
    )

    # NOTE(review): both branches navigate away right after returning, which
    # may race with the backend call to _store_in_huggingface_dataset; confirm
    # the store reliably completes.
    post_hit_js = """
    function(state) {
        if (state["assignmentId"] !== "") {
            // If there is an assignmentId, then the submitter is on mturk
            // and we need to submit their HIT.
            // NOTE(review): this is the mturk *sandbox* endpoint; confirm the
            // production URL before running real HITs.
            const form = document.createElement('form');
            form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
            form.method = 'post';
            for (const key in state) {
                const hiddenField = document.createElement('input');
                hiddenField.type = 'hidden';
                hiddenField.name = key;
                // Arrays/objects (e.g. state.data) would otherwise be sent as
                // "[object Object]"; serialize them as JSON instead.
                const value = state[key];
                hiddenField.value = (typeof value === 'object') ? JSON.stringify(value) : value;
                form.appendChild(hiddenField);
            }
            document.body.appendChild(form);
            form.submit();
        } else {
            // If there is no assignmentId, then we assume that the submitter is
            // on huggingface.co and we can't submit a HIT to mturk. But
            // _store_in_huggingface_dataset will still store their example in
            // our dataset without an assignmentId. The following line here
            // loads the app again so the user can enter in another "fake" HIT.
            window.location.href = window.location.href;
        }
        // Return the state in both branches: it becomes the input of
        // _store_in_huggingface_dataset.
        return [state];
    }
    """
    submit_hit_button.click(
        _store_in_huggingface_dataset,
        inputs=[state],
        outputs=[state],
        _js=post_hit_js,
    )
# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module (for example by a reloader) does not launch the UI.
if __name__ == "__main__":
    demo.launch()