Spaces:
Runtime error
Runtime error
import os
import subprocess
import time
from typing import List, Optional

from distilabel.steps.generators.data import LoadDataFromDicts
from distilabel.steps.expand import ExpandColumns
from distilabel.steps.keep import KeepColumns
from distilabel.steps.tasks.self_instruct import SelfInstruct
from distilabel.steps.tasks.evol_instruct.base import EvolInstruct
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import TextGenerationToArgilla
from dotenv import load_dotenv

from domain import (
    DomainExpert,
    CleanNumberedList,
    create_topics,
    create_examples_template,
    APPLICATION_DESCRIPTION,
)
# Load variables from a local ``.env`` file into the process environment
# (python-dotenv). This runs once, at import time of this module.
load_dotenv()
def define_pipeline(
    argilla_api_key: str,
    argilla_api_url: str,
    argilla_dataset_name: str,
    topics: List[str],
    perspectives: List[str],
    domain_expert_prompt: str,
    examples: List[dict],
    hub_token: str,
    endpoint_base_url: str,
):
    """Build the distilabel pipeline for a specific domain.

    The steps form one linear chain, in this order:

        load_data -> self-instruct -> expand_columns -> clean_numbered_list
        -> evol_instruction_complexity -> expand_columns_evolved
        -> domain_expert -> keep_columns -> text_generation_to_argilla

    Args:
        argilla_api_key: API key for the Argilla push step.
        argilla_api_url: URL of the Argilla server.
        argilla_dataset_name: Argilla dataset the results are written to
            (in the ``"admin"`` workspace).
        topics: Passed together with ``perspectives`` to ``create_topics``;
            every resulting term becomes one ``{"input": term}`` seed row.
        perspectives: Second argument to ``create_topics``.
        domain_expert_prompt: System prompt for the ``DomainExpert`` task.
        examples: Passed to ``create_examples_template``; the result is the
            ``DomainExpert`` template.
        hub_token: Used as ``api_key`` for the shared inference-endpoint LLM.
        endpoint_base_url: Base URL of the inference endpoint used by every
            LLM-backed step.

    Returns:
        The constructed (not yet executed) ``Pipeline``.
    """
    seed_terms = create_topics(topics, perspectives)
    examples_template = create_examples_template(examples)

    with Pipeline("farming") as pipeline:
        # Steps must be instantiated inside the ``with`` block so they attach
        # to this pipeline.
        loader = LoadDataFromDicts(
            name="load_data",
            data=[{"input": seed} for seed in seed_terms],
            batch_size=64,
        )

        # One LLM client is shared by all generation steps.
        shared_llm = InferenceEndpointsLLM(
            base_url=endpoint_base_url,
            api_key=hub_token,
        )

        instruction_generator = SelfInstruct(
            name="self-instruct",
            application_description=APPLICATION_DESCRIPTION,
            num_instructions=5,
            input_batch_size=8,
            llm=shared_llm,
        )

        instruction_evolver = EvolInstruct(
            name="evol_instruction_complexity",
            llm=shared_llm,
            num_evolutions=2,
            store_evolutions=True,
            input_batch_size=8,
            include_original_instruction=True,
            input_mappings={"instruction": "question"},
        )

        # Fan the list column out into one row per generated instruction.
        instruction_expander = ExpandColumns(
            name="expand_columns",
            columns={"instructions": "question"},
        )

        list_cleaner = CleanNumberedList(name="clean_numbered_list")

        evolution_expander = ExpandColumns(
            name="expand_columns_evolved",
            columns={"evolved_instructions": "evolved_questions"},
        )

        expert = DomainExpert(
            name="domain_expert",
            llm=shared_llm,
            input_batch_size=8,
            input_mappings={"instruction": "evolved_questions"},
            output_mappings={"generation": "domain_expert_answer"},
            _system_prompt=domain_expert_prompt,
            _template=examples_template,
        )

        column_filter = KeepColumns(
            name="keep_columns",
            columns=["model_name", "evolved_questions", "domain_expert_answer"],
        )

        argilla_sink = TextGenerationToArgilla(
            name="text_generation_to_argilla",
            dataset_name=argilla_dataset_name,
            dataset_workspace="admin",
            api_url=argilla_api_url,
            api_key=argilla_api_key,
            input_mappings={
                "instruction": "evolved_questions",
                "generation": "domain_expert_answer",
            },
        )

        # Wire each step to the next one in execution order.
        stages = [
            loader,
            instruction_generator,
            instruction_expander,
            list_cleaner,
            instruction_evolver,
            evolution_expander,
            expert,
            column_filter,
            argilla_sink,
        ]
        for upstream, downstream in zip(stages, stages[1:]):
            upstream.connect(downstream)

    return pipeline
def serialize_pipeline(
    argilla_api_key: str,
    argilla_api_url: str,
    argilla_dataset_name: str,
    topics: List[str],
    perspectives: List[str],
    domain_expert_prompt: str,
    hub_token: str,
    endpoint_base_url: str,
    pipeline_config_path: str = "pipeline.yaml",
    examples: Optional[List[dict]] = None,
):
    """Build the domain pipeline and serialize it to a YAML file.

    All arguments except ``pipeline_config_path`` are forwarded unchanged to
    ``define_pipeline``.

    Args:
        argilla_api_key: API key for the Argilla push step.
        argilla_api_url: URL of the Argilla server.
        argilla_dataset_name: Argilla dataset the results are written to.
        topics: Topics used to create the seed inputs.
        perspectives: Perspectives combined with ``topics``.
        domain_expert_prompt: System prompt for the domain-expert task.
        hub_token: API key for the inference endpoint.
        endpoint_base_url: Base URL of the inference endpoint.
        pipeline_config_path: Destination path of the YAML file. An existing
            file is overwritten.
        examples: Examples for the domain-expert template. ``None`` (the
            default) means no examples.
    """
    # Use ``None`` as the default instead of ``[]``. A list default would be
    # one object shared by every call, so any mutation leaks between calls.
    if examples is None:
        examples = []

    pipeline = define_pipeline(
        argilla_api_key=argilla_api_key,
        argilla_api_url=argilla_api_url,
        argilla_dataset_name=argilla_dataset_name,
        topics=topics,
        perspectives=perspectives,
        domain_expert_prompt=domain_expert_prompt,
        hub_token=hub_token,
        endpoint_base_url=endpoint_base_url,
        examples=examples,
    )
    pipeline.save(path=pipeline_config_path, overwrite=True, format="yaml")
def create_pipelines_run_command(
    pipeline_config_path: str = "pipeline.yaml",
    argilla_dataset_name: str = "domain_specific_datasets",
):
    """Return the argv list that runs a serialized pipeline via the distilabel CLI.

    Args:
        pipeline_config_path: YAML file given to ``--config``.
        argilla_dataset_name: Value assigned to the ``dataset_name`` runtime
            parameter of the ``text_generation_to_argilla`` step.

    Returns:
        The command as a list of strings, meant to be executed without a shell.
    """
    cli_prefix = ["python", "-m", "distilabel", "pipeline", "run"]
    dataset_override = (
        f"text_generation_to_argilla.dataset_name={argilla_dataset_name}"
    )
    return [
        *cli_prefix,
        "--config",
        pipeline_config_path,
        "--param",
        dataset_override,
    ]
def run_pipeline(
    pipeline_config_path: str = "pipeline.yaml",
    argilla_dataset_name: str = "domain_specific_datasets",
):
    """Run the serialized pipeline in a subprocess and stream its output.

    The command comes from ``create_pipelines_run_command``. The child's
    stderr is merged into its stdout, so error messages are streamed together
    with normal output.

    Args:
        pipeline_config_path: Path of the pipeline YAML file to run.
        argilla_dataset_name: Override for the Argilla dataset name.

    Yields:
        Each line of the child's combined stdout/stderr, decoded as UTF-8.
        Undecodable bytes are replaced, and the trailing newline is kept.
    """
    command_to_run = create_pipelines_run_command(
        pipeline_config_path=pipeline_config_path,
        argilla_dataset_name=argilla_dataset_name,
    )
    # Merge stderr into stdout. A separate stderr PIPE that is never read can
    # fill its OS buffer and block the child forever, and its content (e.g.
    # tracebacks) would otherwise be lost.
    process = subprocess.Popen(
        command_to_run, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    )
    # readline() already blocks until a full line or EOF (b""), so no polling
    # sleep is needed. iter() stops cleanly at EOF.
    for raw_line in iter(process.stdout.readline, b""):
        yield raw_line.decode("utf-8", errors="replace")
    # EOF: the child has closed its output. Reap it so it does not linger as
    # a zombie process.
    process.stdout.close()
    process.wait()