# Spaces: Runtime error (status shown on the Hugging Face Spaces page this file was captured from)
# Basic example for doing model-in-the-loop dynamic adversarial data collection (DADC)
# using Gradio Blocks.
import json
import random
from urllib.parse import parse_qs

import gradio as gr
import requests
from transformers import pipeline
# Sentiment classifier the workers try to fool. The model is pinned to the
# checkpoint this example was written against so results don't drift if the
# transformers default for "sentiment-analysis" changes.
pipe = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

total_cnt = 2  # How many examples per HIT

# MTurk externalSubmit endpoint for ExternalQuestion HITs. This is the Sandbox;
# switch to "https://www.mturk.com/mturk/externalSubmit" for real data collection.
MTURK_SUBMIT_URL = "https://workersandbox.mturk.com/mturk/externalSubmit"

# Placeholder MTurk puts in `assignmentId` while a worker is only previewing the
# HIT; a preview has no real assignment and cannot be submitted.
_PREVIEW_ASSIGNMENT_ID = "ASSIGNMENT_ID_NOT_AVAILABLE"

# Gradio 4 renamed the event-listener keyword for client-side JavaScript from
# `_js` to `js`; use whichever the installed Gradio understands.
_JS_KWARG = "js" if int(gr.__version__.split(".")[0]) >= 4 else "_js"


def _predict(txt, tgt, state):
    """Run the model on one worker-written statement and record the outcome.

    Args:
        txt: Statement typed by the worker.
        tgt: Label the worker says is correct ("Positive" or "Negative");
            None when no radio option has been selected yet.
        state: The session's state dict (initialised from ``state_dict``);
            mutated in place and also returned.

    Returns:
        A 6-tuple matching the outputs wired to ``submit_ex_button.click``:
        label confidences, result markdown, the updated state, a visibility
        update for the example-submit column, a visibility update for the
        HIT-submit column, and the new progress markdown.
    """
    # Refuse incomplete input instead of scoring it: with no target selected,
    # `pred_label != tgt` is always true and would be counted as a fool.
    if tgt is None or not (txt or "").strip():
        return (
            gr.update(),
            "Please enter a statement and select its correct label.",
            state,
            gr.update(),
            gr.update(),
            gr.update(),
        )

    pred = pipe(txt)[0]
    pred_label = pred["label"].title()
    other_label = "negative" if pred_label.lower() == "positive" else "positive"
    # Assumes a two-class model: the other class gets the complementary score.
    pred_confidences = {pred_label.lower(): pred["score"], other_label: 1 - pred["score"]}

    fooled = pred_label != tgt
    ret = f"Target: **{tgt}**. Model prediction: **{pred_label}**\n\n"
    if fooled:
        state["fooled"] += 1
        ret += " You fooled the model! Well done!"
    else:
        ret += " You did not fool the model! Too bad, try again!"

    # Record the actual example (not just the feedback text) so the submitted
    # HIT contains the collected data.
    state["data"].append(
        {
            "text": txt,
            "target": tgt,
            "prediction": pred_label,
            "confidences": pred_confidences,
            "fooled": fooled,
        }
    )
    state["cnt"] += 1

    done = state["cnt"] >= total_cnt
    toggle_final_submit = gr.update(visible=done)
    toggle_example_submit = gr.update(visible=not done)
    new_state_md = f"State: {state['cnt']}/{total_cnt} ({state['fooled']} fooled)"
    return pred_confidences, ret, state, toggle_example_submit, toggle_final_submit, new_state_md


def _submit(state, dummy):
    """Post the collected HIT data to MTurk's externalSubmit endpoint.

    Args:
        state: The session's state dict as built up by ``_predict``.
        dummy: The page's ``window.location.search`` string (e.g.
            "?assignmentId=...&hitId=..."), injected by the client-side JS
            attached to the "Submit HIT" button.

    Raises:
        gr.Error: if the URL carries no usable assignment ID (missing, or the
            preview placeholder), or if MTurk answers with a non-2xx status.
    """
    query = parse_qs((dummy or "").lstrip("?"))
    # MTurk appends the ID as `assignmentId`; `assignment_id` is still accepted
    # for links built against the earlier version of this script.
    assignment_id = (query.get("assignmentId") or query.get("assignment_id") or [""])[0]
    if not assignment_id or assignment_id == _PREVIEW_ASSIGNMENT_ID:
        raise gr.Error("No assignment ID provided, unable to submit")
    state["assignmentId"] = assignment_id

    # requests form-encodes lists as repeated fields and dicts by their keys
    # only, so nested values are sent as JSON strings.
    payload = {
        "assignmentId": assignment_id,
        "cnt": state["cnt"],
        "fooled": state["fooled"],
        "data": json.dumps(state["data"]),
        "metadata": json.dumps(state["metadata"]),
    }
    # NOTE(review): MTurk documents externalSubmit as a form POST from the
    # worker's browser; confirm a server-side POST is accepted for your HITs.
    response = requests.post(MTURK_SUBMIT_URL, data=payload, timeout=30)
    if not response.ok:
        raise gr.Error(f"MTurk submission failed (HTTP {response.status_code})")


demo = gr.Blocks()
with demo:
    dummy = gr.Textbox(visible=False)  # dummy for passing assignmentId
    # Per-session state (gr.State replaces the removed gr.Variable); Gradio
    # copies the initial dict for each session.
    state_dict = {"assignmentId": "", "cnt": 0, "fooled": 0, "data": [], "metadata": {}}
    state = gr.State(state_dict)
    gr.Markdown("# DADC in Gradio example")
    gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")
    state_display = gr.Markdown(f"State: 0/{total_cnt} (0 fooled)")

    # Input fields
    text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
    labels = ["Positive", "Negative"]
    # NOTE(review): shuffled once when the app is built, so every session sees
    # the same order; per-session shuffling would need a load event.
    random.shuffle(labels)
    label_input = gr.Radio(choices=labels, label="Target (correct) label")
    label_output = gr.Label()
    text_output = gr.Markdown()
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")

    # Button event handlers
    submit_ex_button.click(
        _predict,
        inputs=[text_input, label_input, state],
        outputs=[label_output, text_output, state, example_submit, final_submit, state_display],
    )
    # The JS runs in the browser first and swaps the hidden textbox's value for
    # the page's query string, which carries the MTurk assignmentId.
    submit_hit_button.click(
        _submit,
        inputs=[state, dummy],
        outputs=None,
        **{_JS_KWARG: "function(state, dummy) { return [state, window.location.search]; }"},
    )

if __name__ == "__main__":
    demo.launch(share=True)