|
from pathlib import Path |
|
from urllib.parse import urlparse, parse_qs |
|
|
|
import gradio as gr |
|
import io |
|
import pandas as pd |
|
import spaces |
|
|
|
from generate import model_id, stream_jsonl_file |
|
|
|
# Largest ``size`` query parameter that stream_output accepts.
MAX_SIZE = 20
# Values used when the query omits ``seed`` / ``size`` respectively.
DEFAULT_SEED, DEFAULT_SIZE = 42, 3
|
|
|
def _parse_query(query: str):
    """Split a generation query into ``(filename, prompt, columns, size, seed)``.

    The query looks like ``name.jsonl?prompt=...&columns=a,b&size=N&seed=S``;
    every parameter is optional.

    Raises:
        gr.Error: if the file name is empty, if ``size``/``seed`` is not an
            integer, or if ``size`` is outside ``1..MAX_SIZE``.
    """
    parsed = urlparse(query)
    # Only the final component of the *path* part is kept; the query string is
    # parsed separately so that a "/" inside e.g. the prompt is preserved.
    filename = Path(parsed.path).name
    if filename in ("", ".."):
        raise gr.Error("Please enter a file name to generate, e.g. movies_data.jsonl")
    params = parse_qs(parsed.query)
    prompt = params["prompt"][0] if "prompt" in params else ""
    columns = (
        [column.strip() for column in params["columns"][0].split(",") if column.strip()]
        if "columns" in params
        else []
    )
    try:
        size = int(params["size"][0]) if "size" in params else DEFAULT_SIZE
        seed = int(params["seed"][0]) if "seed" in params else DEFAULT_SEED
    except ValueError:
        raise gr.Error("The size and seed parameters must be integers.") from None
    if size < 1:
        # size is also used as a divisor when continuing an existing dataset.
        raise gr.Error("Minimum size is 1.")
    if size > MAX_SIZE:
        raise gr.Error(f"Maximum size is {MAX_SIZE}. Duplicate this Space to remove this limit.")
    return filename, prompt, columns, size, seed


def _in_progress_outputs(df, content, state_msg):
    """Return the 5 UI updates shown while rows are being generated (buttons disabled)."""
    return (
        df,
        "```json\n" + content + "\n```",
        gr.Button(state_msg),
        gr.Button("Generate one more batch", interactive=False),
        gr.DownloadButton("⬇️ Download", interactive=False),
    )


@spaces.GPU(duration=120)
def stream_output(query: str, continue_content: str = ""):
    """Generate a JSON Lines dataset for ``query`` and stream UI updates.

    Args:
        query: file name with optional URL-style parameters, e.g.
            ``"common_first_names.jsonl?columns=first_name,popularity&size=10"``.
        continue_content: previously generated JSON Lines to extend. When it
            has columns, they replace any ``columns`` parameter and the seed
            is offset by the number of batches already present.

    Yields:
        5-tuples ``(dataframe, markdown, generate_button, more_button,
        download_button)``. The final one re-enables the buttons and points
        the download button at the saved file.

    Raises:
        gr.Error: on an empty file name or an invalid ``size``/``seed``.
    """
    # Only the final path component is used on disk, which keeps the write in
    # the working directory; stream_more_output reads the file back by this name.
    output_name = Path(query).name
    if output_name in ("", ".."):
        raise gr.Error("Please enter a file name to generate, e.g. movies_data.jsonl")
    filename, prompt, columns, size, seed = _parse_query(query)

    content = continue_content
    df = pd.read_json(io.StringIO(content), lines=True, convert_dates=False)
    continue_content_size = len(df)
    if list(df.columns):
        # Continuing an existing dataset: keep its schema.
        columns = list(df.columns)
    else:
        # Placeholder table shown before the first row arrives.
        df = pd.DataFrame({"1": [], "2": [], "3": []})
    total = continue_content_size + size
    yield _in_progress_outputs(df, content, f"⚙️ Generating... [{continue_content_size + 1}/{total}]")

    for i, chunk in enumerate(stream_jsonl_file(
        filename=filename,
        prompt=prompt,
        columns=columns,
        # Offset by the number of full batches already generated; presumably so
        # a continuation does not repeat earlier rows.
        seed=seed + (continue_content_size // size),
        size=size,
    )):
        content += chunk
        df = pd.read_json(io.StringIO(content), lines=True, convert_dates=False)
        yield _in_progress_outputs(df, content, f"⚙️ Generating... [{continue_content_size + i + 1}/{total}]")

    with open(output_name, "w", encoding="utf-8") as f:
        f.write(content)

    yield (
        df,
        "```json\n" + content + "\n```",
        gr.Button("Generate dataset"),
        gr.Button("Generate one more batch", visible=True, interactive=True),
        gr.DownloadButton("⬇️ Download", value=output_name, visible=True, interactive=True),
    )
|
|
|
|
|
def stream_more_output(query: str):
    """Generate one more batch for ``query``, extending the previously saved file.

    Reads the JSON Lines content saved under ``Path(query).name`` and streams
    the same UI updates as ``stream_output`` while appending new rows.

    Raises:
        gr.Error: if nothing has been generated yet under this file name
            (e.g. the file name box was edited after the last generation).
    """
    output_name = Path(query).name
    try:
        with open(output_name, "r", encoding="utf-8") as f:
            continue_content = f.read()
    except (FileNotFoundError, IsADirectoryError):
        raise gr.Error(f"Nothing generated yet for {output_name!r}. Click 'Generate dataset' first.") from None
    # Pass the raw query: stream_output derives the same on-disk name itself,
    # and the full query keeps parameters that contain "/".
    yield from stream_output(query=query, continue_content=continue_content)
|
|
|
|
|
title = "LLM DataGen"

description = "\n\n".join(
    (
        # Intro line linking to the model id imported from ``generate``.
        "Generate and stream synthetic dataset files in `{JSON Lines}` format "
        f"(currently using [{model_id}](https://huggingface.co/{model_id}))",
        "Disclaimer: LLM data generation is an area of active research with known problems such as biased generation and incorrect information.",
    )
)

# Example queries; the query-string parameters (columns, size, ...) are parsed
# by stream_output and override its defaults.
examples = [
    "movies_data.jsonl",
    "dungeon_and_dragon_characters.jsonl",
    "bad_amazon_reviews_on_defunct_products_that_people_hate.jsonl",
    "common_first_names.jsonl?columns=first_name,popularity&size=10",
]
|
|
|
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    filename_comp = gr.Textbox(examples[0], placeholder=examples[0], label="File name to generate")
    generate_button = gr.Button("Generate dataset")
    with gr.Tab("Dataset"):
        dataframe_comp = gr.DataFrame()
    with gr.Tab("File content"):
        file_content_comp = gr.Markdown()
    with gr.Row():
        # Hidden until a first generation finishes (stream_output reveals them).
        generate_more_button = gr.Button("Generate one more batch", visible=False, interactive=False, scale=3)
        download_button = gr.DownloadButton("⬇️ Download", visible=False, interactive=False, scale=1)
    # Order must match the 5-tuples yielded by stream_output / stream_more_output.
    outputs = [dataframe_comp, file_content_comp, generate_button, generate_more_button, download_button]
    # Clicking an example runs the generation immediately. The component is not
    # bound to a name so that ``examples`` keeps referring to the list of queries.
    gr.Examples(examples, filename_comp, outputs, fn=stream_output, run_on_click=True)
    generate_button.click(stream_output, filename_comp, outputs)
    generate_more_button.click(stream_more_output, filename_comp, outputs)


demo.launch()