Spaces:
Sleeping
Sleeping
import json | |
import time | |
from itertools import count, islice | |
from multiprocessing.pool import ThreadPool | |
from queue import Queue, Empty | |
from typing import Any, Callable, Iterable, Iterator, TypeVar | |
import gradio as gr | |
import ijson | |
import pandas as pd | |
import requests | |
from datasets import Features, Value, Sequence | |
from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
from huggingface_hub import InferenceClient | |
from utils import StringIteratorIO | |
# Chat model used for the rewrites, and the client that talks to it.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
client = InferenceClient(model_id)

# Single HTTP session shared by the datasets-server requests in this file.
session = requests.Session()

# Number of rows fetched for the input preview.
NUM_ROWS_PREVIEW = 3

# Placeholder frame with three empty columns named "1", "2" and "3".
empty_dataframe = pd.DataFrame({name: [] for name in ("1", "2", "3")})
# Prompt template sent to the chat model. Placeholders:
#   {dataset} - id of the dataset being rewritten
#   {prompt}  - transformation requested by the user
#   {rows}    - serialized sample rows
# Adjacent literals are concatenated, so every sentence must carry its own
# trailing space (or newline) to avoid fusing with the next one.
REWRITE_DATASET = (
    "A Machine Learning practitioner is looking for a dataset similar to '{dataset}' but slightly different. "
    "They want you to rewrite the dataset and apply this transformation: {prompt}. "
    "The first rows of the dataset are below in JSON format (one JSON object per line):\n\n{rows}\n\n"
    "Rewrite those rows from the '{dataset}' dataset using the same format (one JSON object per line). "
    "Try to keep some of the text or meaning intact, and apply the requested transformation '{prompt}'."
)
# NOTE(review): the nesting below was rebuilt from a copy of this file whose
# indentation had been stripped -- confirm the Row/Column layout against the
# running app.
with gr.Blocks() as demo:
    gr.Markdown(
        "# 🤗 WIP Dataset ReWriter ✍️✨\n\n"
        "Adjust, translate or transform completely existing datasets.\n\n"
    )

    # Dataset picker: Hub search box plus subset / split selectors.
    # Both selectors start hidden.
    with gr.Row():
        with gr.Column(scale=3):
            dataset_search = HuggingfaceHubSearch(
                label="Hub Dataset ID",
                placeholder="Search for dataset id on Huggingface",
                search_type="dataset",
            )
        subset_dropdown = gr.Dropdown(
            info="Subset",
            show_label=False,
            visible=False,
        )
        split_dropdown = gr.Dropdown(
            info="Split",
            show_label=False,
            visible=False,
        )

    # Input preview: a hidden frame plus a read-only, wrapped one for display.
    gr.Markdown("### Input")
    input_preview = gr.DataFrame(visible=False)
    pretty_input_preview = gr.DataFrame(interactive=False, wrap=True)

    # Rewrite controls: free-text instruction and an editable output format.
    gr.Markdown("### ReWrite")
    input_prompt = gr.Textbox(
        label="Enter the adjustment or transformation to apply to the dataset:"
    )
    with gr.Accordion("Modify Format", open=False):
        output_format = gr.Textbox(
            interactive=True,
            show_label=False,
            container=False,
        )
    rewrite_button = gr.Button("ReWrite Dataset", variant="primary")
    output_preview = gr.DataFrame(interactive=False, wrap=True)
    save_button = gr.Button("ReWrite Full Dataset", interactive=False)
############ | |
# | |
# Utils | |
# | |
########### | |
def stream_rows(dataset: str, subset: str, split: str, batch_size: int = 100) -> Iterable[dict[str, Any]]:
    """Yield the rows of a dataset split, fetched page by page from datasets-server.

    Args:
        dataset: Dataset id sent as the ``dataset`` query parameter.
        subset: Subset name sent as the ``config`` query parameter.
        split: Split name sent as the ``split`` query parameter.
        batch_size: Number of rows requested per HTTP call.

    Yields:
        The ``"row"`` payload of each item in the response's ``"rows"`` list.

    Raises:
        RuntimeError: If a response contains an ``"error"`` key.
    """
    for page in count():
        rows_resp = session.get(
            "https://datasets-server.huggingface.co/rows",
            # Let requests build the query string so that values containing
            # spaces, "&", "#", ... are percent-encoded instead of corrupting
            # the URL.
            params={
                "dataset": dataset,
                "config": subset,
                "split": split,
                "offset": page * batch_size,
                "length": batch_size,
            },
            timeout=10,
        ).json()
        if "error" in rows_resp:
            raise RuntimeError(rows_resp["error"])
        # An empty page means the previous one was the last.
        if not rows_resp["rows"]:
            break
        for row_item in rows_resp["rows"]:
            yield row_item["row"]
T = TypeVar("T")


def batched(it: Iterable[T], n: int) -> Iterator[list[T]]:
    """Split *it* into consecutive lists of at most *n* items.

    Only the final list may hold fewer than *n* items; nothing is yielded
    for an empty input.
    """
    source = iter(it)
    while True:
        chunk = list(islice(source, n))
        if not chunk:
            return
        yield chunk
def stream_reponse(messages: list[dict[str, str]], response_format=None) -> Iterator[str]:
    """Stream a chat completion from the module-level ``client``, chunk by chunk.

    The request is attempted up to three times. A retry only happens when the
    connection fails before any chunk has been received.

    Args:
        messages: Chat messages forwarded to ``client.chat_completion``.
        response_format: Forwarded as-is to ``client.chat_completion``.

    Yields:
        ``choices[0].delta.content`` of every streamed chunk.

    Raises:
        requests.exceptions.ConnectionError: If the connection drops after a
            chunk was already received, or if the last attempt fails too.
    """
    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        message = None
        try:
            for message in client.chat_completion(
                messages=messages,
                max_tokens=5000,
                stream=True,
                top_p=0.8,
                seed=42,
                response_format=response_format,
            ):
                yield message.choices[0].delta.content
        except requests.exceptions.ConnectionError as e:
            # Part of the answer was already yielded: restarting would feed
            # duplicated text to the consumer, so propagate instead. Also
            # propagate once the attempts are exhausted rather than ending
            # the stream silently.
            if message is not None or attempt == max_attempts:
                raise
            # An exception cannot be concatenated to a str; format it instead.
            print(f"{e}\n\nRetrying in 1sec")
            time.sleep(1)
            continue
        break
def stream_rewrite_dataset_row_by_row(dataset: str, rows: list[dict[str, str]], prompt: str, format: str) -> Iterator[dict[str, str]]:
    """Ask the model to rewrite *rows* and yield the rewritten rows as they are parsed.

    Args:
        dataset: Dataset id inserted in the prompt.
        rows: Rows to rewrite; sent to the model as ``{"data": rows}``.
        prompt: Requested transformation. A blank prompt becomes ``""``,
            anything else is cut to its first 1000 characters.
        format: Schema placed under ``"items"`` in the response format
            (the caller in this file passes the parsed JSON format).

    Yields:
        Up to ``len(rows)`` items parsed from the ``data`` array of the reply.
    """
    if prompt.strip():
        prompt = prompt[:1000]
    else:
        prompt = ""
    num_rows = len(rows)

    user_content = REWRITE_DATASET.format(
        dataset=dataset,
        rows=json.dumps({"data": rows}),
        prompt=prompt,
    )
    messages = [{"role": "user", "content": user_content}]

    # Constrain the reply to {"data": [...]} holding exactly num_rows items.
    data_schema = {"type": "array", "maxItems": num_rows, "minItems": num_rows, "items": format}
    response_format = {"type": "json", "value": {"properties": {"data": data_schema}, "required": ["data"]}}

    print("go")
    # Parse the streamed text incrementally so rows are yielded before the
    # whole reply has been received.
    reply_stream = StringIteratorIO(stream_reponse(messages, response_format=response_format))
    parsed_rows = ijson.items(reply_stream, "data.item", buf_size=4)
    yield from islice(parsed_rows, num_rows)
    print("done")
def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None: | |
for i, result in enumerate(func(**kwargs)): | |
queue.put(result) | |
return None | |
def iflatmap_unordered(
    func: Callable[..., Iterable[T]],
    *,
    kwargs_iterable: Iterable[dict],
) -> Iterable[T]:
    """Run ``func(**kwargs)`` for every kwargs dict in a thread pool and yield the results.

    Each call is submitted to a ``ThreadPool`` and its items are pushed to a
    shared queue, from which this generator yields them.

    Args:
        func: Callable returning an iterable; invoked once per kwargs dict.
        kwargs_iterable: Keyword-argument dicts, one per call to *func*.

    Yields:
        Every item produced by every call.
    """
    queue = Queue()
    with ThreadPool() as pool:
        pending = [
            pool.apply_async(_write_generator_to_queue, (queue, func, kwargs))
            for kwargs in kwargs_iterable
        ]
        try:
            while True:
                try:
                    yield queue.get(timeout=0.05)
                except Empty:
                    # Stop only once every worker finished and nothing is
                    # left to drain.
                    workers_done = all(job.ready() for job in pending)
                    if workers_done and queue.empty():
                        break
        finally:
            # Fetch each outcome in case there is an error to raise.
            for job in pending:
                job.get(timeout=0.05)
def features_to_format(features: Features) -> dict:
    """Translate dataset features into the schema dict used to constrain the model output.

    Mapping applied recursively:
        * ``Value`` with "int" in its dtype   -> ``{"type": "integer"}``
        * ``Value`` with "float" in its dtype -> ``{"type": "number"}``
        * any other ``Value``                 -> ``{"type": "string"}``
        * one-element list                    -> array of the converted element
        * dict                                -> ``"properties"`` + ``"required"`` (all keys)
        * ``Sequence`` of a dict              -> properties whose values are arrays
        * ``Sequence`` of anything else       -> array of the converted inner feature
        * anything else                       -> ``{"type": "string"}``

    Args:
        features: Features of the dataset.

    Returns:
        The schema dict for one row.
    """

    def feature_to_format(feature):
        if isinstance(feature, Value):
            if "int" in feature.dtype:
                return {"type": "integer"}
            elif "float" in feature.dtype:
                return {"type": "number"}
            else:
                return {"type": "string"}
        elif isinstance(feature, list):
            return {"type": "array", "items": feature_to_format(feature[0])}
        elif isinstance(feature, dict):
            return {
                "properties": {k: feature_to_format(v) for k, v in feature.items()},
                "required": list(feature),
            }
        elif isinstance(feature, Sequence):
            if isinstance(feature.feature, dict):
                # Each key of the inner dict becomes an array property.
                # Iterate the inner dict itself: iterating the converted
                # schema would walk its "properties"/"required" keys, and a
                # Sequence object is not iterable.
                return {
                    "properties": {
                        k: {"type": "array", "items": feature_to_format(v)}
                        for k, v in feature.feature.items()
                    },
                    "required": list(feature.feature),
                }
            else:
                return {"type": "array", "items": feature_to_format(feature.feature)}
        else:
            return {"type": "string"}

    return feature_to_format(features)
############ | |
# | |
# Events | |
# | |
########### | |
def _resolve_dataset_selection(dataset: str, default_subset: str, default_split: str) -> "tuple[str | None, str | None, dict]":
    """Pick the subset and split to preview and build the matching component updates.

    Args:
        dataset: Dataset id typed or selected by the user.
        default_subset: Subset to keep if the dataset has it; otherwise the
            first subset is used.
        default_split: Split to keep if the subset has it; otherwise the
            first split is used.

    Returns:
        ``(subset, split, updates)``. ``subset`` and ``split`` are ``None``
        when *dataset* has no "/" separator or the info request returns an
        error; ``updates`` then only hides both dropdowns. Otherwise
        ``updates`` fills both dropdowns and the output format textbox.
    """

    def hidden_dropdowns() -> dict:
        # Updates returned whenever no subset/split can be resolved.
        return {
            subset_dropdown: gr.Dropdown(visible=False),
            split_dropdown: gr.Dropdown(visible=False),
        }

    if "/" not in dataset.strip().strip("/"):
        return None, None, hidden_dropdowns()
    info_resp = session.get(
        "https://datasets-server.huggingface.co/info",
        # Passed as params so the dataset id is percent-encoded.
        params={"dataset": dataset},
        timeout=3,
    ).json()
    if "error" in info_resp:
        return None, None, hidden_dropdowns()
    subsets: list[str] = list(info_resp["dataset_info"])
    subset = default_subset if default_subset in subsets else subsets[0]
    # "splits" is a mapping keyed by split name: materialize the names so the
    # positional fallback below works and the dropdown gets a plain list.
    splits: list[str] = list(info_resp["dataset_info"][subset]["splits"])
    split = default_split if default_split in splits else splits[0]
    json_format = json.dumps(
        features_to_format(Features.from_dict(info_resp["dataset_info"][subset]["features"])),
        indent=2,
    )
    return subset, split, {
        # A dropdown is only shown when there is an actual choice to make.
        subset_dropdown: gr.Dropdown(value=subset, choices=subsets, visible=len(subsets) > 1),
        split_dropdown: gr.Dropdown(value=split, choices=splits, visible=len(splits) > 1),
        output_format: gr.Textbox(json_format, lines=json_format.count("\n") + 1),
    }
def _show_input_preview(dataset: str, default_subset: str, default_split: str) -> dict:
    """Build the component updates that display the first rows of the selected split.

    Returns only the updates from ``_resolve_dataset_selection`` when no
    subset/split could be resolved. Otherwise those updates are returned
    together with the raw preview frame and a copy whose values are all
    converted with ``str``.
    """
    subset, split, output = _resolve_dataset_selection(
        dataset,
        default_subset=default_subset,
        default_split=default_split,
    )
    if subset is None or split is None:
        return output

    # Fetch just enough rows for the preview.
    row_stream = stream_rows(dataset, subset, split, batch_size=NUM_ROWS_PREVIEW)
    rows = list(islice(row_stream, NUM_ROWS_PREVIEW))
    stringified_rows = [
        {column: str(value) for column, value in row.items()}
        for row in rows
    ]
    return {
        input_preview: pd.DataFrame(rows),
        pretty_input_preview: pd.DataFrame(stringified_rows),
        **output,
    }
def show_input_from_dataset_search(dataset: str) -> dict:
    """Build the preview updates for *dataset*, preferring the "default" subset and "train" split."""
    preferred = {"default_subset": "default", "default_split": "train"}
    return _show_input_preview(dataset, **preferred)
def show_input_from_subset_dropdown(dataset: str, subset: str) -> dict:
    """Build the preview updates for *dataset*, preferring *subset* and the "train" split."""
    preferred = {"default_subset": subset, "default_split": "train"}
    return _show_input_preview(dataset, **preferred)
def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
    """Build the preview updates for *dataset*, preferring *subset* and *split*."""
    preferred = {"default_subset": subset, "default_split": split}
    return _show_input_preview(dataset, **preferred)
def rewrite(dataset: str, subset: str, split: str, input_preview_df: pd.DataFrame, prompt: str, json_format: str) -> Iterator[pd.DataFrame]:
    """Rewrite the previewed rows and yield a growing frame of the results.

    Args:
        dataset: Dataset id forwarded to the rewriting prompt.
        subset: Not used by this function.
        split: Not used by this function.
        input_preview_df: Rows to rewrite, one record per frame row.
        prompt: Transformation requested by the user.
        json_format: JSON text describing the expected shape of one row.

    Yields:
        A frame holding every row rewritten so far, once per new row.
    """
    source_rows = input_preview_df.to_dict(orient="records")
    row_schema = json.loads(json_format)
    rewritten = []
    for new_row in stream_rewrite_dataset_row_by_row(
        dataset=dataset,
        rows=source_rows,
        prompt=prompt,
        format=row_schema,
    ):
        rewritten.append(new_row)
        yield pd.DataFrame(rewritten)
# Start the app only when this file is executed as a script, so importing the
# module (e.g. to reuse its helpers) does not start a server.
if __name__ == "__main__":
    demo.launch()