bicycle_maintenance_config_space / pages /3_🌱 Generate Dataset.py
ignacioct's picture
recommiting all files
8773ff3
raw history blame
No virus
7.51 kB
import streamlit as st
from streamlit.errors import EntryNotFoundError
from hub import pull_seed_data_from_repo, push_pipeline_to_hub
from defaults import (
DEFAULT_SYSTEM_PROMPT,
PIPELINE_PATH,
PROJECT_NAME,
ARGILLA_SPACE_REPO_ID,
DATASET_REPO_ID,
ARGILLA_SPACE_NAME,
ARGILLA_URL,
PROJECT_SPACE_REPO_ID,
HUB_USERNAME,
)
from utils import project_sidebar
from pipeline import serialize_pipeline, run_pipeline, create_pipelines_run_command
st.set_page_config(
page_title="Domain Data Grower",
page_icon="πŸ§‘β€πŸŒΎ",
)
project_sidebar()
################################################################################
# HEADER
################################################################################
st.header("πŸ§‘β€πŸŒΎ Domain Data Grower")
st.divider()
st.subheader("Step 3. Run the pipeline to generate synthetic data")
st.write(
"Define the project details, including the project name, domain, and API credentials"
)
###############################################################
# CONFIGURATION
###############################################################
st.divider()
st.markdown("### Pipeline Configuration")
st.write("πŸ€— Hub details to pull the seed data")
hub_username = st.text_input("Hub Username", HUB_USERNAME)
project_name = st.text_input("Project Name", PROJECT_NAME)
repo_id = f"{hub_username}/{project_name}"
hub_token = st.text_input("Hub Token", type="password")
st.write("πŸ€– Inference configuration")
st.write(
"Add the url of the Huggingface inference API or endpoint that your pipeline should use. You can find compatible models here:"
)
st.link_button(
"πŸ€— Inference compaptible models on the hub",
"https://huggingface.co/models?pipeline_tag=text-generation&other=endpoints_compatible&sort=trending",
)
base_url = st.text_input("Base URL")
st.write("πŸ”¬ Argilla API details to push the generated dataset")
argilla_url = st.text_input("Argilla API URL", ARGILLA_URL)
argilla_api_key = st.text_input("Argilla API Key", "owner.apikey")
argilla_dataset_name = st.text_input("Argilla Dataset Name", project_name)
st.divider()
###############################################################
# LOCAL
###############################################################
st.markdown("### Run the pipeline")
st.write(
"Once you've defined the pipeline configuration, you can run the pipeline locally or on this space."
)
st.write(
"""We recommend running the pipeline locally if you're planning on generating a large dataset. \
But running the pipeline on this space is a handy way to get started quickly. Your synthetic
samples will be pushed to Argilla and available for review.
"""
)
st.write(
"""If you're planning on running the pipeline on the space, be aware that it \
will take some time to complete and you will need to maintain a \
connection to the space."""
)
if st.button("πŸ’» Run pipeline locally", key="run_pipeline_local"):
if all(
[
argilla_api_key,
argilla_url,
base_url,
hub_username,
project_name,
hub_token,
argilla_dataset_name,
]
):
with st.spinner("Pulling seed data from the Hub..."):
seed_data = pull_seed_data_from_repo(
repo_id=f"{hub_username}/{project_name}",
hub_token=hub_token,
)
domain = seed_data["domain"]
perspectives = seed_data["perspectives"]
topics = seed_data["topics"]
examples = seed_data["examples"]
domain_expert_prompt = seed_data["domain_expert_prompt"]
with st.spinner("Serializing the pipeline configuration..."):
serialize_pipeline(
argilla_api_key=argilla_api_key,
argilla_dataset_name=argilla_dataset_name,
argilla_api_url=argilla_url,
topics=topics,
perspectives=perspectives,
pipeline_config_path=PIPELINE_PATH,
domain_expert_prompt=domain_expert_prompt or DEFAULT_SYSTEM_PROMPT,
hub_token=hub_token,
endpoint_base_url=base_url,
examples=examples,
)
push_pipeline_to_hub(
pipeline_path=PIPELINE_PATH,
hub_token=hub_token,
hub_username=hub_username,
project_name=project_name,
)
st.success(f"Pipeline configuration saved to {hub_username}/{project_name}")
st.info(
"To run the pipeline locally, you need to have the `distilabel` library installed. You can install it using the following command:"
)
st.text(
"Execute the following command to generate a synthetic dataset from the seed data:"
)
command_to_run = create_pipelines_run_command(
hub_token=hub_token,
pipeline_config_path=PIPELINE_PATH,
argilla_dataset_name=argilla_dataset_name,
)
st.code(
f"""
pip install git+https://github.com/argilla-io/distilabel.git
git clone https://huggingface.co/{hub_username}/{project_name}
cd {project_name}
{' '.join(command_to_run[2:])}
""",
language="bash",
)
else:
st.error("Please fill all the required fields.")
###############################################################
# SPACE
###############################################################
if st.button("πŸ”₯ Run pipeline right here, right now!"):
if all(
[
argilla_api_key,
argilla_url,
base_url,
hub_username,
project_name,
hub_token,
argilla_dataset_name,
]
):
with st.spinner("Pulling seed data from the Hub..."):
try:
seed_data = pull_seed_data_from_repo(
repo_id=f"{hub_username}/{project_name}",
hub_token=hub_token,
)
except EntryNotFoundError:
st.error(
"Seed data not found. Please make sure you pushed the data seed in Step 2."
)
domain = seed_data["domain"]
perspectives = seed_data["perspectives"]
topics = seed_data["topics"]
examples = seed_data["examples"]
domain_expert_prompt = seed_data["domain_expert_prompt"]
with st.spinner("Serializing the pipeline configuration..."):
serialize_pipeline(
argilla_api_key=argilla_api_key,
argilla_dataset_name=argilla_dataset_name,
argilla_api_url=argilla_url,
topics=topics,
perspectives=perspectives,
pipeline_config_path=PIPELINE_PATH,
domain_expert_prompt=domain_expert_prompt or DEFAULT_SYSTEM_PROMPT,
hub_token=hub_token,
endpoint_base_url=base_url,
examples=examples,
)
with st.spinner("Starting the pipeline..."):
logs = run_pipeline(PIPELINE_PATH)
st.success(f"Pipeline started successfully! πŸš€")
with st.expander(label="View Logs", expanded=True):
for out in logs:
st.text(out)
else:
st.error("Please fill all the required fields.")