"""Gradio annotation app for spot-checking GecLM dataset samples.

For each dataset in ``DATASETS``, rows are drawn from one randomly chosen file
under ``BASE_S3_DIR``. Annotators either skip a sample ("Next") or flag it
("Bad") with a reason. Each flagged sample is uploaded as its own JSONL file to
the ``HuggingFaceGECLM/data_feedback`` dataset repo on the Hugging Face Hub.
"""

import math
import os
import random
import tempfile
import uuid
from datetime import datetime

import gradio as gr
import jsonlines
import pyarrow as pa
import pyarrow.parquet as pq  # explicit: ``import pyarrow`` alone does not guarantee ``pa.parquet``
import s3fs
from datasets import Dataset
from huggingface_hub import HfApi

# S3 client authenticated from the standard AWS credential environment variables.
S3 = s3fs.S3FileSystem(
    anon=False,
    key=os.getenv("AWS_ACCESS_KEY_ID"),
    secret=os.getenv("AWS_SECRET_ACCESS_KEY"),
)

DEFAULT_SHUFFLE_BUFFER_SIZE_RATIO = 5  # NOTE(review): not referenced anywhere in this file
BASE_S3_DIR = "s3://geclm-datasets/samples/"

# Sub-directories of BASE_S3_DIR offered in the UI dropdown.
DATASETS = [
    "c4",
    "bigcode_python_code",
    "bigcode_python_github_issues",
    "bigcode_python_jupyter_markdowned_clean_dedup",
    "books3",
    "gutenberg_raw",
    "reddit_threaded",
    "enwiki_data",
    "s2orc_dedup",
    "stackexchange2",
    "commoncrawl",
]


def get_parquet_lines(dataset, sample_size=100):
    """Return up to *sample_size* shuffled rows from one random parquet file of *dataset*.

    Args:
        dataset: Sub-directory name under ``BASE_S3_DIR`` (one of ``DATASETS``).
        sample_size: Maximum number of rows to return.

    Returns:
        A list of row dicts, as produced by ``pyarrow.Table.to_pylist``.

    Raises:
        FileNotFoundError: If nothing matches ``BASE_S3_DIR + dataset + "/*"``.
    """
    pattern = BASE_S3_DIR + dataset + "/*"
    s3_paths = S3.glob(pattern)
    if not s3_paths:
        raise FileNotFoundError(f"Nothing found at {pattern}")
    print("Number of parquet files", len(s3_paths))

    s3_path = random.choice(s3_paths)
    print("Reading", s3_path)

    lines = []
    with S3.open(s3_path) as f:
        pf = pq.ParquetFile(f)
        for ix_row_group in range(pf.metadata.num_row_groups):
            # Load the file one row group at a time (presumably ~1000 rows each).
            # open_input_stream would return raw bytes rather than rows.
            # NOTE: every row group is accumulated, so the whole file ends up in memory.
            table = pf.read_row_group(ix_row_group)
            lines.extend(table.to_pylist())

    random.shuffle(lines)
    return lines[:sample_size]


def get_local_lines(dataset):
    """Load every record from ``data/<dataset>_examples_with_stats.json`` (JSON Lines)."""
    with jsonlines.open("data/{}_examples_with_stats.json".format(dataset), "r") as reader:
        return list(reader)


def line_generator(lines_dict, dataset):
    """Yield the records in ``lines_dict[dataset]`` one at a time (finite; does not loop)."""
    for line in lines_dict[dataset]:
        yield line


# TODO: parallelize these loads; every dataset is fetched sequentially at import time.
# NOTE(review): local_lines / line_generators_local are built but never used by the UI below.
local_lines = {dataset: get_local_lines(dataset) for dataset in DATASETS}
s3_lines = {dataset: get_parquet_lines(dataset) for dataset in DATASETS}
line_generators_local = {dataset: line_generator(local_lines, dataset) for dataset in DATASETS}
line_generators_s3 = {dataset: line_generator(s3_lines, dataset) for dataset in DATASETS}


def _text_column(sample):
    """Return the key holding *sample*'s document text: ``"text"`` if present, else ``"content"``."""
    return "text" if "text" in sample else "content"


def _next_s3_sample(dataset):
    """Return the next pre-fetched S3 sample for *dataset*, refilling from S3 when exhausted.

    The generators in ``line_generators_s3`` are finite, because ``get_parquet_lines``
    returns at most ``sample_size`` rows. When one runs dry, a fresh batch is fetched
    so that ``StopIteration`` does not escape into the Gradio callbacks.

    Raises:
        RuntimeError: If a freshly fetched batch is also empty.
    """
    sample = next(line_generators_s3[dataset], None)
    if sample is None:
        s3_lines[dataset] = get_parquet_lines(dataset)
        line_generators_s3[dataset] = line_generator(s3_lines, dataset)
        sample = next(line_generators_s3[dataset], None)
        if sample is None:
            raise RuntimeError(f"No samples available for dataset {dataset!r}")
    return sample


def send_report(sample, dataset, reason, annotator, campaign):
    """Upload a flag report for *sample* to the ``HuggingFaceGECLM/data_feedback`` dataset repo.

    The report is a one-record JSONL file. The sample's text (``"text"``, or
    ``"content"`` when there is no ``"text"`` key) is stored under ``"text"``, and
    all remaining fields of the sample are stored under ``"metadata"``.

    Args:
        sample: Row dict as returned by ``get_parquet_lines``. It is not modified.
        dataset: Name of the dataset the sample came from.
        reason: Free-text reason for flagging.
        annotator: Optional annotator name from the UI.
        campaign: Optional annotation-campaign name from the UI.
    """
    metadata = dict(sample)  # shallow copy so the caller's dict keeps its text field
    text = metadata.pop(_text_column(metadata))

    # Prefer "id", fall back to "title", otherwise leave the doc id empty.
    if "id" in metadata:
        sample_id = metadata["id"]
    elif "title" in metadata:
        sample_id = metadata["title"]
    else:
        sample_id = ""

    report_name = "report-{}.jsonl".format(uuid.uuid4())
    # Write to a per-call temporary directory. A shared "report.jsonl" in the CWD
    # could be overwritten by a concurrent report before it is uploaded.
    with tempfile.TemporaryDirectory() as tmp_dir:
        report_path = os.path.join(tmp_dir, report_name)
        with jsonlines.open(report_path, "w") as writer:
            writer.write(
                {
                    "dataset": dataset,
                    "docid": sample_id,
                    "text": text,
                    "metadata": metadata,
                    "reason": reason,
                    "annotator": annotator,
                    "campaign": campaign,
                    "timestamp": str(datetime.now()),
                }
            )
        api = HfApi()
        api.upload_file(
            path_or_fileobj=report_path,
            path_in_repo=report_name,
            repo_id="HuggingFaceGECLM/data_feedback",
            repo_type="dataset",
            token=os.environ.get("geclm_token"),
        )


description = """
GecLM annotations. All annotations are recorded in the [data_feedback](https://huggingface.co/datasets/HuggingFaceGECLM/data_feedback) dataset.
"""


if __name__ == "__main__":
    demo = gr.Blocks()

    with demo:
        # Holds the row dict currently shown, so "Bad" can report it.
        current_sample_state = gr.State(dict())
        description_md = gr.Markdown(value=description)
        with gr.Row():
            annotator = gr.Textbox(
                lines=1,
                max_lines=1,
                placeholder="Optionally provide your name here if you'd like it to be recorded.",
                label="Annotator",
            )
            campaign = gr.Textbox(
                lines=1,
                max_lines=1,
                placeholder="Optionally provide the name of the annotation campaign for ease of filtering the reports.",
                label="Annotation campaign",
            )
        with gr.Row():
            dataset = gr.Dropdown(
                choices=DATASETS,
                value="Pick a dataset below",
                label="Dataset",
            )
        # The controls below stay hidden until a dataset has been picked.
        with gr.Row():
            reason_txt = gr.Textbox(
                label="Flagging reason",
                placeholder="Provide the reason for flagging if you think the sample is bad.",
                visible=False,
            )
        with gr.Row():
            bad_btn = gr.Button("Bad ❌", visible=False)
            good_btn = gr.Button("Next ✅", visible=False)
        with gr.Row():
            text = gr.Textbox(visible=False, label="Datapoint", lines=500)

        def next_line(dataset):
            """Show the next sample of *dataset* and reveal the reason box and both buttons."""
            sample = _next_s3_sample(dataset)
            return [
                gr.update(value=sample[_text_column(sample)], visible=True),
                sample,
                gr.update(visible=True),
                gr.update(visible=True),
                gr.update(visible=True),
            ]

        def bad_line(current_sample, dataset, reason, annotator, campaign):
            """Report the shown sample as bad, then show the next one and clear the reason box."""
            send_report(current_sample, dataset, reason, annotator, campaign)
            sample = _next_s3_sample(dataset)
            return [
                sample[_text_column(sample)],
                gr.update(
                    value="",
                    placeholder="Provide the reason for flagging if you think the sample is bad.",
                ),
                sample,
            ]

        good_btn.click(
            next_line,
            inputs=dataset,
            outputs=[text, current_sample_state, reason_txt, good_btn, bad_btn],
        )
        dataset.change(
            next_line,
            inputs=dataset,
            outputs=[text, current_sample_state, reason_txt, good_btn, bad_btn],
        )
        bad_btn.click(
            bad_line,
            inputs=[current_sample_state, dataset, reason_txt, annotator, campaign],
            outputs=[text, reason_txt, current_sample_state],
        )

    demo.launch(enable_queue=False, debug=True)