Spaces:
Runtime error
Runtime error
import json | |
import os | |
import threading | |
import uuid | |
from pathlib import Path | |
from urllib.parse import parse_qs | |
from datasets import load_dataset | |
import gradio as gr | |
from dotenv import load_dotenv | |
from huggingface_hub import Repository | |
import random | |
from utils import force_git_push | |
# These variables are for storing the MTurk HITs in a Hugging Face dataset.
# Values from a local .env file (when one exists) are loaded first so that
# the os.getenv lookups below can see them.
if Path(".env").is_file():
    load_dotenv(".env")

DATASET_REPO_URL, FORCE_PUSH, HF_TOKEN = (
    os.getenv(env_name) for env_name in ("DATASET_REPO_URL", "FORCE_PUSH", "HF_TOKEN")
)

PROMPT_TEMPLATES = Path("prompt_templates")
DATA_FILENAME = "data.jsonl"
# Annotations are appended to this file inside the "data" working copy below.
DATA_FILE = str(Path("data") / DATA_FILENAME)

# Local working copy of the dataset repo that DATA_FILE lives in.
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN)
# Examples shown to annotators.
ds = load_dataset(
    "HuggingFaceH4/instruction-pilot-outputs",
    split="train",
    use_auth_token=HF_TOKEN,
)

TOTAL_CNT = 10  # How many user inputs per HIT
# This function pushes the HIT data written in data.jsonl to our Hugging Face
# dataset every PUSH_FREQUENCY seconds. Adjust the frequency to suit your needs.
PUSH_FREQUENCY = 60


def asynchronous_push(f_stop):
    """Commit and push pending changes in ``repo``, then schedule the next run.

    Args:
        f_stop: A ``threading.Event``. While it is unset, every call schedules
            the next one on a ``threading.Timer`` PUSH_FREQUENCY seconds later.

    If the working copy is dirty, all changes are added (with
    ``auto_lfs_track=True``), committed and pushed. The push goes through
    ``force_git_push`` when the FORCE_PUSH environment variable is ``"yes"``,
    otherwise through ``repo.git_push()``.

    Any exception raised while doing so is printed and the next run is still
    scheduled. A single failed push (e.g. a transient network error) therefore
    no longer ends the Timer thread and disables all future syncing.
    """
    try:
        if repo.is_repo_clean():
            print("Repo currently clean. Ignoring push_to_hub")
        else:
            repo.git_add(auto_lfs_track=True)
            repo.git_commit("Auto commit by space")
            if FORCE_PUSH == "yes":
                force_git_push(repo)
            else:
                repo.git_push()
    except Exception:  # boundary of a background task: report, don't die
        import traceback

        print("Push to hub failed:")
        traceback.print_exc()

    if not f_stop.is_set():
        # call again in PUSH_FREQUENCY seconds
        threading.Timer(PUSH_FREQUENCY, asynchronous_push, [f_stop]).start()
# Gradio app container; the UI is populated in the ``with demo:`` block below.
demo = gr.Blocks()

# Start syncing annotations to the Hub. The first call runs synchronously in
# the importing thread; later calls are scheduled by asynchronous_push itself.
# Nothing in this file ever sets f_stop, so the loop runs for the life of the
# process.
f_stop = threading.Event()
asynchronous_push(f_stop)
def random_sample_with_least_annotated_examples_first():
    """Return one example from ``ds``, preferring the least-annotated ones.

    Annotation counts come from DATA_FILE (one JSON object per line). Only rows
    with a non-empty ``assignmentId`` -- i.e. written by a worker on MTurk --
    are counted.

    The dataset is shuffled before being sorted by count so that ties are
    broken randomly. This relies on ``Dataset.sort`` keeping the shuffled
    order among equal counts -- TODO confirm.

    A missing DATA_FILE (nothing annotated yet) counts as zero annotations for
    every example, and blank lines in the file are ignored.

    Returns:
        The selected example as a dict whose ``"outputs"`` list has been
        replaced by 2 entries sampled without replacement. ``random.sample``
        raises ValueError if the example has fewer than 2 outputs.
    """
    from collections import Counter

    try:
        with open(DATA_FILE, "r", encoding="utf-8") as jsonlfile:
            lines = jsonlfile.readlines()
    except FileNotFoundError:
        # No HIT has been written yet, so every example has zero annotations.
        lines = []

    id_to_count = Counter()
    for line in lines:
        if not line.strip():
            continue
        annotation = json.loads(line)
        # Only include annotations by actual turkers in the count.
        if annotation.get("assignmentId"):
            id_to_count[annotation["id"]] += 1

    ds_with_annotation_counts = ds.map(
        lambda example: {"annotation_count": id_to_count.get(example["id"], 0)}
    )
    # Shuffle first so examples with equal counts are picked in random order.
    ds_with_annotation_counts = ds_with_annotation_counts.shuffle().sort("annotation_count")
    example = ds_with_annotation_counts[0]
    # We only want to give the annotator 2 choices, so we sample 2 outputs without replacement.
    example["outputs"] = random.sample(example["outputs"], 2)
    return example
def prompt_pretty_markdown(prompt):
    """Return *prompt* with each "Input:" marker turned into a bold heading.

    Every occurrence of the marker is replaced by a bold ``<b>Input:</b>``
    surrounded by blank lines. The Markdown component then shows the input on
    its own paragraph.
    """
    return prompt.replace("Input:", "\n\n<b>Input:</b>\n\n")
with demo:
    dummy = gr.Textbox(visible=False)  # dummy for passing assignmentId
    initial_sample = random_sample_with_least_annotated_examples_first()

    # We keep track of state as a JSON.
    # NOTE(review): this dict (taskId and first sample included) is built once
    # when the app is constructed, so presumably every session starts from the
    # same taskId/sample -- confirm whether per-session init is intended.
    state_dict = {
        "taskId": str(uuid.uuid4()),
        "assignmentId": "",
        "cnt": 0,
        "data": [initial_sample],
    }
    state = gr.JSON(state_dict, visible=False)

    gr.Markdown("# Choose the more helpful response for the input")
    gr.Markdown("By 'helpful', we mean whatever answer you personally find more useful.")

    def _response_choices(outputs):
        """Return the four Radio choices offered for a pair of model outputs."""
        return [
            "(a) " + outputs[0],
            "(b) " + outputs[1],
            "(c) Both (a) and (b) are similarly good",
            "(d) Both (a) and (b) are similarly bad",
        ]

    def _select_response(selected_response, state, dummy):
        """Record the worker's choice and advance to the next example.

        Args:
            selected_response: Label of the chosen Radio option. It is falsy
                when nothing has been selected yet.
            state: The per-session state dict (taskId, assignmentId, cnt, data).
            dummy: The page's ``window.location.search``, injected by
                get_window_location_search_js (e.g. ``"?assignmentId=..."``).

        Returns:
            An 8-tuple matching the ``outputs`` of submit_response_button.click.
            It holds updates for past_conversation, select_response,
            submit_response_button, submit_hit_button, submit_hit_button_preview
            and state_display, followed by state and dummy.

        When the TOTAL_CNT-th response is submitted, every answered example of
        this HIT is appended to DATA_FILE.
        """
        if not selected_response:
            # Don't do anything if the worker didn't select things yet: one
            # no-op update for each of the 6 UI outputs, then state and dummy
            # unchanged -- 8 values, matching the click handler's outputs.
            return (
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                state,
                dummy,
            )

        # Drop the leading "?" of the query string before parsing it.
        query = parse_qs(dummy[1:])
        on_mturk = (
            "assignmentId" in query
            and query["assignmentId"][0] != "ASSIGNMENT_ID_NOT_AVAILABLE"
        )
        if on_mturk:
            # It seems that someone is using this app on mturk. Store the
            # assignmentId *before* the HIT data may be written below, so the
            # written rows (and the later submit_hit_button POST) carry it and
            # the turker can get credit for their HIT.
            state["assignmentId"] = query["assignmentId"][0]

        state["cnt"] += 1
        # NOTE: despite the "left" label, this shows submissions completed so far.
        state_display = f"Submissions left in HIT: {state['cnt']}/{TOTAL_CNT}"
        done = state["cnt"] == TOTAL_CNT
        state["data"][-1]["selected_response"] = selected_response

        if done:
            # Write the HIT data to our local dataset because the worker has
            # submitted everything now.
            json_data_with_assignment_id = [
                json.dumps(
                    dict(
                        {"assignmentId": state["assignmentId"], "taskId": state["taskId"]},
                        **datum,
                    )
                )
                for datum in state["data"]
            ]
            with open(DATA_FILE, "a", encoding="utf-8") as jsonlfile:
                jsonlfile.write("\n".join(json_data_with_assignment_id) + "\n")

        # Only one of the two final-submit buttons is ever shown, and only once done.
        if on_mturk:
            toggle_final_submit = gr.update(visible=done)
            toggle_final_submit_preview = gr.update(visible=False)
        else:
            toggle_final_submit_preview = gr.update(visible=done)
            toggle_final_submit = gr.update(visible=False)
        toggle_submit_response_button = gr.update(visible=not done)

        new_sample = random_sample_with_least_annotated_examples_first()
        new_outputs = [obj["output"] for obj in new_sample["outputs"]]
        state["data"].append(new_sample)
        past_conversation = gr.update(value=prompt_pretty_markdown(new_sample["prompt"]))
        select_response = gr.update(choices=_response_choices(new_outputs), value="")
        return (
            past_conversation,
            select_response,
            toggle_submit_response_button,
            toggle_final_submit,
            toggle_final_submit_preview,
            state_display,
            state,
            dummy,
        )

    # Input fields
    gr.Markdown('<span style="padding:7px;color:black;background:#ffd21e;border-radius:10px"><b>Prompt</b></span>')
    past_conversation = gr.Markdown(value=prompt_pretty_markdown(initial_sample["prompt"]))
    initial_outputs = [obj["output"] for obj in initial_sample["outputs"]]
    gr.Markdown('<span style="padding:7px;color:black;background:#ffd21e;border-radius:10px"><b>Select the most helpful response</b></span>')
    select_response = gr.Radio(choices=_response_choices(initial_outputs), label="")
    submit_response_button = gr.Button("Submit Response")
    submit_hit_button = gr.Button("Submit HIT", visible=False)
    submit_hit_button_preview = gr.Button(
        "Submit Work (preview mode; no MTurk HIT credit, but your examples will still be stored)",
        visible=False,
    )
    # NOTE: despite the "left" label, this shows submissions completed so far.
    state_display = gr.Markdown(f"Submissions left in HIT: 0/{TOTAL_CNT}")

    # Button event handlers
    # Passes the Radio value and state through unchanged and replaces the
    # hidden dummy input with the page's query string, so _select_response
    # can read the MTurk assignmentId.
    get_window_location_search_js = """
        function(select_response, state, dummy) {
            return [select_response, state, window.location.search];
        }
    """

    submit_response_button.click(
        _select_response,
        inputs=[select_response, state, dummy],
        outputs=[
            past_conversation,
            select_response,
            submit_response_button,
            submit_hit_button,
            submit_hit_button_preview,
            state_display,
            state,
            dummy,
        ],
        _js=get_window_location_search_js,
    )

    # Builds and submits a hidden form with one field per top-level state key.
    # NOTE(review): this posts to the MTurk *sandbox* endpoint; confirm it is
    # not meant to be the production endpoint for live HITs.
    post_hit_js = """
        function(state) {
            // If there is an assignmentId, then the submitter is on mturk
            // and has accepted the HIT. So, we need to submit their HIT.
            const form = document.createElement('form');
            form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
            form.method = 'post';
            for (const key in state) {
                const hiddenField = document.createElement('input');
                hiddenField.type = 'hidden';
                hiddenField.name = key;
                hiddenField.value = state[key];
                form.appendChild(hiddenField);
            };
            document.body.appendChild(form);
            form.submit();
            return state;
        }
    """

    submit_hit_button.click(
        lambda state: state,
        inputs=[state],
        outputs=[state],
        _js=post_hit_js,
    )

    refresh_app_js = """
        function(state) {
            // The following line here loads the app again so the user can
            // enter in another preview-mode "HIT".
            window.location.href = window.location.href;
            return state;
        }
    """

    submit_hit_button_preview.click(
        lambda state: state,
        inputs=[state],
        outputs=[state],
        _js=refresh_app_js,
    )

demo.launch()