burtenshaw (HF staff) committed
Commit 26fb24b • 1 Parent(s): 63a8770

Upload 9 files

defaults.py CHANGED
@@ -3,7 +3,7 @@ import json
 
 SEED_DATA_PATH = "seed_data.json"
 PIPELINE_PATH = "pipeline.yaml"
-REMOTE_CODE_PATHS = ["defaults.py", "domain.py", "pipeline.py", "requirements.txt"]
+REMOTE_CODE_PATHS = ["requirements.txt"]
 DIBT_PARENT_APP_URL = "https://argilla-domain-specific-datasets-welcome.hf.space/"
 N_PERSPECTIVES = 5
 N_TOPICS = 5
hub.py CHANGED
@@ -94,7 +94,7 @@ def push_pipeline_to_hub(
     # upload the pipeline to the hub
     hf_api.upload_file(
         path_or_fileobj=pipeline_path,
-        path_in_repo="pipeline.yaml",
+        path_in_repo="pipeline.py",
         token=hub_token,
         repo_id=repo_id,
         repo_type="dataset",
@@ -115,7 +115,7 @@ def push_pipeline_to_hub(
 def pull_seed_data_from_repo(repo_id, hub_token):
     # pull the dataset repo from the hub
     hf_api.hf_hub_download(
-        repo_id=repo_id, token=hub_token, repo_type="dataset", filename=SEED_DATA_PATH, force_download=True
+        repo_id=repo_id, token=hub_token, repo_type="dataset", filename=SEED_DATA_PATH
     )
     return json.load(open(SEED_DATA_PATH))
 
@@ -127,3 +127,25 @@ def push_argilla_dataset_to_hub(
     feedback_dataset = rg.FeedbackDataset.from_argilla(name=name, workspace=workspace)
     local_dataset = feedback_dataset.pull()
     local_dataset.push_to_huggingface(repo_id=repo_id)
+
+
+def push_pipeline_params(
+    pipeline_params,
+    hub_username,
+    hub_token: str,
+    project_name,
+):
+    repo_id = f"{hub_username}/{project_name}"
+    temp_path = mktemp()
+    with open(temp_path, "w") as f:
+        json.dump(pipeline_params, f)
+    # upload the pipeline params to the hub
+    hf_api.upload_file(
+        path_or_fileobj=temp_path,
+        path_in_repo="pipeline_params.json",
+        token=hub_token,
+        repo_id=repo_id,
+        repo_type="dataset",
+    )
+
+    print(f"Pipeline params uploaded to {repo_id}")
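For reference, a minimal sketch of how the new `push_pipeline_params` helper is called. All values below are hypothetical placeholders (the real call site is in `pages/3_🌱 Generate Dataset.py` further down), and `mktemp` is assumed to be imported from `tempfile` at the top of `hub.py`:

    from hub import push_pipeline_params

    # All values here are illustrative placeholders, not values from the commit.
    push_pipeline_params(
        pipeline_params={
            "argilla_api_key": "owner.apikey",
            "argilla_api_url": "https://my-argilla-space.hf.space",
            "argilla_dataset_name": "my_domain_data",
            "endpoint_base_url": "https://api-inference.huggingface.co/models/my-model",
        },
        hub_username="my-username",
        hub_token="hf_xxx",
        project_name="my-project",
    )
    # Serializes the params dict to a temp file and uploads it as
    # pipeline_params.json in the my-username/my-project dataset repo.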
pages/2_πŸ‘©πŸΌβ€πŸ”¬ Describe Domain.py CHANGED
@@ -2,14 +2,9 @@ import json
 
 import streamlit as st
 
-from hub import push_dataset_to_hub
+from hub import push_dataset_to_hub, pull_seed_data_from_repo
 from infer import query
 from defaults import (
-    DEFAULT_DOMAIN,
-    DEFAULT_PERSPECTIVES,
-    DEFAULT_TOPICS,
-    DEFAULT_EXAMPLES,
-    DEFAULT_SYSTEM_PROMPT,
     N_PERSPECTIVES,
     N_TOPICS,
     SEED_DATA_PATH,
@@ -18,12 +13,14 @@ from defaults import (
 )
 from utils import project_sidebar
 
+
 st.set_page_config(
     page_title="Domain Data Grower",
     page_icon="🧑‍🌾",
 )
 project_sidebar()
 
+
 ################################################################################
 # HEADER
 ################################################################################
@@ -37,6 +34,23 @@ st.write(
     "Define the project details, including the project name, domain, and API credentials"
 )
 
+
+################################################################################
+# LOAD EXISTING DOMAIN DATA
+################################################################################
+
+DATASET_REPO_ID = (
+    f"{st.session_state['hub_username']}/{st.session_state['project_name']}"
+)
+SEED_DATA = pull_seed_data_from_repo(
+    DATASET_REPO_ID, hub_token=st.session_state["hub_token"]
+)
+DEFAULT_DOMAIN = SEED_DATA.get("domain", "")
+DEFAULT_PERSPECTIVES = SEED_DATA.get("perspectives", [""])
+DEFAULT_TOPICS = SEED_DATA.get("topics", [""])
+DEFAULT_EXAMPLES = SEED_DATA.get("examples", [{"question": "", "answer": ""}])
+DEFAULT_SYSTEM_PROMPT = SEED_DATA.get("domain_expert_prompt", "")
+
 ################################################################################
 # Domain Expert Section
 ################################################################################
@@ -212,22 +226,6 @@ with tab_raw_seed:
 
     st.divider()
 
-    hub_username = DATASET_REPO_ID.split("/")[0]
-    project_name = DATASET_REPO_ID.split("/")[1]
-    st.write("Define the dataset repo details on the Hub")
-    st.session_state["project_name"] = st.text_input("Project Name", project_name)
-    st.session_state["hub_username"] = st.text_input("Hub Username", hub_username)
-    st.session_state["hub_token"] = st.text_input("Hub Token", type="password", value=None)
-
-    if all(
-        (
-            st.session_state.get("project_name"),
-            st.session_state.get("hub_username"),
-            st.session_state.get("hub_token"),
-        )
-    ):
-        st.success(f"Using the dataset repo {hub_username}/{project_name} on the Hub")
-
 
     if st.button("🤗 Push Dataset Seed") and all(
        (
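The page now seeds its form defaults from the dataset repo instead of hard-coded constants. For reference, the seed data it pulls is a JSON object shaped roughly like this (a sketch inferred from the `.get()` calls above; the values are illustrative placeholders):

    # Illustrative shape of seed_data.json; the keys match the .get() calls
    # in the diff above, the values are placeholders.
    seed_data = {
        "domain": "farming",
        "perspectives": ["family farming"],
        "topics": ["animal welfare"],
        "examples": [{"question": "...", "answer": "..."}],
        "domain_expert_prompt": "You will assume the role of an expert in ...",
    }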
pages/3_🌱 Generate Dataset.py CHANGED
@@ -1,18 +1,9 @@
 import streamlit as st
 
-from hub import pull_seed_data_from_repo, push_pipeline_to_hub
-from defaults import (
-    DEFAULT_SYSTEM_PROMPT,
-    PIPELINE_PATH,
-    PROJECT_NAME,
-    ARGILLA_URL,
-    HUB_USERNAME,
-    CODELESS_DISTILABEL,
-)
+from defaults import ARGILLA_URL
+from hub import push_pipeline_params, push_pipeline_to_hub
 from utils import project_sidebar
 
-from pipeline import serialize_pipeline, run_pipeline, create_pipelines_run_command
-
 st.set_page_config(
     page_title="Domain Data Grower",
     page_icon="🧑‍🌾",
@@ -27,20 +18,15 @@ project_sidebar()
 st.header("🧑‍🌾 Domain Data Grower")
 st.divider()
 st.subheader("Step 3. Run the pipeline to generate synthetic data")
-st.write("Define the project repos and models that the pipeline will use.")
+st.write("Define the distilabel pipeline for generating the dataset.")
 
-st.divider()
 ###############################################################
 # CONFIGURATION
 ###############################################################
 
-st.markdown("## Pipeline Configuration")
-
-st.markdown("#### 🤗 Hub details to pull the seed data")
-hub_username = st.text_input("Hub Username", HUB_USERNAME)
-project_name = st.text_input("Project Name", PROJECT_NAME)
-repo_id = f"{hub_username}/{project_name}"
-hub_token = st.text_input("Hub Token", type="password")
+hub_username = st.session_state.get("hub_username")
+project_name = st.session_state.get("project_name")
+hub_token = st.session_state.get("hub_token")
 
 st.divider()
 
@@ -89,169 +75,74 @@ st.divider()
 
 st.markdown("## Run the pipeline")
 
-st.write(
-    "Once you've defined the pipeline configuration, you can run the pipeline from your local machine."
+st.markdown(
+    "Once you've defined the pipeline configuration above, you can run the pipeline from your local machine."
 )
 
-if CODELESS_DISTILABEL:
-    st.write(
-        """We recommend running the pipeline locally if you're planning on generating a large dataset. \
-        But running the pipeline on this space is a handy way to get started quickly. Your synthetic
-        samples will be pushed to Argilla and available for review.
-        """
-    )
-    st.write(
-        """If you're planning on running the pipeline on the space, be aware that it \
-        will take some time to complete and you will need to maintain a \
-        connection to the space."""
-    )
 
-if st.button("💻 Run pipeline locally", key="run_pipeline_local"):
-    if all(
-        [
-            argilla_api_key,
-            argilla_url,
-            base_url,
-            hub_username,
-            project_name,
-            hub_token,
-            argilla_dataset_name,
-        ]
-    ):
-        with st.spinner("Pulling seed data from the Hub..."):
-            try:
-                seed_data = pull_seed_data_from_repo(
-                    repo_id=f"{hub_username}/{project_name}",
-                    hub_token=hub_token,
-                )
-            except Exception:
-                st.error(
-                    "Seed data not found. Please make sure you pushed the data seed in Step 2."
-                )
-
-        domain = seed_data["domain"]
-        perspectives = seed_data["perspectives"]
-        topics = seed_data["topics"]
-        examples = seed_data["examples"]
-        domain_expert_prompt = seed_data["domain_expert_prompt"]
-
-        with st.spinner("Serializing the pipeline configuration..."):
-            serialize_pipeline(
-                argilla_api_key=argilla_api_key,
-                argilla_dataset_name=argilla_dataset_name,
-                argilla_api_url=argilla_url,
-                topics=topics,
-                perspectives=perspectives,
-                pipeline_config_path=PIPELINE_PATH,
-                domain_expert_prompt=domain_expert_prompt or DEFAULT_SYSTEM_PROMPT,
-                hub_token=hub_token,
-                endpoint_base_url=base_url,
-                examples=examples,
-            )
-            push_pipeline_to_hub(
-                pipeline_path=PIPELINE_PATH,
-                hub_token=hub_token,
-                hub_username=hub_username,
-                project_name=project_name,
-            )
-
-        st.success(f"Pipeline configuration saved to {hub_username}/{project_name}")
-
-        st.info(
-            "To run the pipeline locally, you need to have the `distilabel` library installed. You can install it using the following command:"
-        )
-        st.text(
-            "Execute the following command to generate a synthetic dataset from the seed data:"
-        )
-        command_to_run = create_pipelines_run_command(
-            hub_token=hub_token,
-            pipeline_config_path=PIPELINE_PATH,
-            argilla_dataset_name=argilla_dataset_name,
-            argilla_api_key=argilla_api_key,
-            argilla_api_url=argilla_url,
-        )
-        st.code(
-            f"""
-            pip install git+https://github.com/argilla-io/distilabel.git
-            git clone https://huggingface.co/datasets/{hub_username}/{project_name}
-            cd {project_name}
-            pip install -r requirements.txt
-            {' '.join(["python"] + command_to_run[1:])}
-            """,
-            language="bash",
-        )
-        st.subheader(
-            "👩‍🚀 If you want to access the pipeline and manipulate the locally, you can do:"
-        )
-        st.code(
-            """
-            git clone https://github.com/huggingface/data-is-better-together
-            cd domain-specific-datasets
-            """
-        )
-    else:
-        st.error("Please fill all the required fields.")
-
-###############################################################
-# SPACE
-###############################################################
-if CODELESS_DISTILABEL:
-    if st.button("🔥 Run pipeline right here, right now!"):
-        if all(
-            [
-                argilla_api_key,
-                argilla_url,
-                base_url,
-                hub_username,
-                project_name,
-                hub_token,
-                argilla_dataset_name,
-            ]
-        ):
-            with st.spinner("Pulling seed data from the Hub..."):
-                try:
-                    seed_data = pull_seed_data_from_repo(
-                        repo_id=f"{hub_username}/{project_name}",
-                        hub_token=hub_token,
-                    )
-                except Exception as e:
-                    st.error(
-                        "Seed data not found. Please make sure you pushed the data seed in Step 2."
-                    )
-
-            domain = seed_data["domain"]
-            perspectives = seed_data["perspectives"]
-            topics = seed_data["topics"]
-            examples = seed_data["examples"]
-            domain_expert_prompt = seed_data["domain_expert_prompt"]
-
-            serialize_pipeline(
-                argilla_api_key=argilla_api_key,
-                argilla_dataset_name=argilla_dataset_name,
-                argilla_api_url=argilla_url,
-                topics=topics,
-                perspectives=perspectives,
-                pipeline_config_path=PIPELINE_PATH,
-                domain_expert_prompt=domain_expert_prompt or DEFAULT_SYSTEM_PROMPT,
-                hub_token=hub_token,
-                endpoint_base_url=base_url,
-                examples=examples,
-            )
-
-            with st.spinner("Starting the pipeline..."):
-                logs = run_pipeline(
-                    pipeline_config_path=PIPELINE_PATH,
-                    argilla_api_key=argilla_api_key,
-                    argilla_api_url=argilla_url,
-                    hub_token=hub_token,
-                    argilla_dataset_name=argilla_dataset_name,
-                )
-
-            st.success(f"Pipeline started successfully! 🚀")
-
-            with st.expander(label="View Logs", expanded=True):
-                for out in logs:
-                    st.text(out)
-        else:
-            st.error("Please fill all the required fields.")
+if all(
+    [
+        argilla_api_key,
+        argilla_url,
+        base_url,
+        hub_username,
+        project_name,
+        hub_token,
+        argilla_dataset_name,
+    ]
+):
+    push_pipeline_params(
+        pipeline_params={
+            "argilla_api_key": argilla_api_key,
+            "argilla_api_url": argilla_url,
+            "argilla_dataset_name": argilla_dataset_name,
+            "endpoint_base_url": base_url,
+        },
+        hub_username=hub_username,
+        hub_token=hub_token,
+        project_name=project_name,
+    )
+
+    push_pipeline_to_hub(
+        pipeline_path="pipeline.py",
+        hub_username=hub_username,
+        hub_token=hub_token,
+        project_name=project_name,
+    )
+
+    st.markdown(
+        "To run the pipeline locally, you need to have the `distilabel` library installed. You can install it using the following command:"
+    )
+
+    st.code(
+        """
+        # Install the distilabel library
+        pip install git+https://github.com/argilla-io/distilabel.git
+        """
+    )
+
+    st.markdown("Next, you'll need to clone your dataset repo and run the pipeline:")
+
+    st.code(
+        f"""
+        git clone https://huggingface.co/datasets/{hub_username}/{project_name}
+        cd {project_name}
+        pip install -r requirements.txt
+        """
+    )
+
+    st.markdown("Finally, you can run the pipeline using the following command:")
+
+    st.code(
+        """
+        huggingface-cli login
+        python pipeline.py""",
+        language="bash",
+    )
+    st.markdown(
+        "👩‍🚀 If you want to customise the pipeline, take a look at `pipeline.py` and the [distilabel docs](https://distilabel.argilla.io/)"
+    )
+
+else:
+    st.info("Please fill all the required fields.")
 
 
pipeline.py CHANGED
@@ -1,95 +1,142 @@
-import subprocess
-import sys
-import time
-from typing import List
+import json
+from textwrap import dedent
+from typing import Any, Dict, List
 
-from distilabel.steps.generators.data import LoadDataFromDicts
-from distilabel.steps.expand import ExpandColumns
-from distilabel.steps.keep import KeepColumns
-from distilabel.steps.tasks.self_instruct import SelfInstruct
-from distilabel.steps.tasks.evol_instruct.base import EvolInstruct
 from distilabel.llms.huggingface import InferenceEndpointsLLM
 from distilabel.pipeline import Pipeline
 from distilabel.steps import TextGenerationToArgilla
-from dotenv import load_dotenv
-
-from domain import (
-    DomainExpert,
-    CleanNumberedList,
-    create_topics,
-    create_examples_template,
-    APPLICATION_DESCRIPTION,
-)
-
-load_dotenv()
-
-
-def define_pipeline(
-    argilla_api_key: str,
-    argilla_api_url: str,
-    argilla_dataset_name: str,
-    topics: List[str],
-    perspectives: List[str],
-    domain_expert_prompt: str,
-    examples: List[dict],
-    hub_token: str,
-    endpoint_base_url: str,
-):
-    """Define the pipeline for the specific domain."""
-
-    terms = create_topics(topics, perspectives)
-    template = create_examples_template(examples)
-    with Pipeline("farming") as pipeline:
+from distilabel.steps.expand import ExpandColumns
+from distilabel.steps.generators.data import LoadDataFromDicts
+from distilabel.steps.tasks.self_instruct import SelfInstruct
+from distilabel.steps.tasks.text_generation import TextGeneration
+from distilabel.steps.tasks.typing import ChatType
+
+
+################################################################################
+# Functions to create task prompts
+################################################################################
+
+
+def create_application_instruction(domain: str, examples: List[Dict[str, str]]):
+    """Create the instruction for the Self-Instruct task."""
+    system_prompt = dedent(
+        f"""You are an AI assistant that generates queries around the domain of {domain}.
+        You should expect not basic but profound questions from your users.
+        The queries should reflect a diversity of vision and economic and political positions.
+        The queries may know about different methods of {domain}.
+        The queries can be positioned politically, economically, socially, or practically.
+        Also take into account the impact of diverse causes on diverse domains."""
+    )
+    for example in examples:
+        question = example["question"]
+        answer = example["answer"]
+        system_prompt += f"""\n- Question: {question}\n- Answer: {answer}\n"""
+    return system_prompt
+
+
+def create_seed_terms(topics: List[str], perspectives: List[str]) -> List[str]:
+    """Create seed terms for self instruct to start from."""
+
+    return [
+        f"{topic} from a {perspective} perspective"
+        for topic in topics
+        for perspective in perspectives
+    ]
+
+
+################################################################################
+# Define our custom step for the domain expert
+################################################################################
+
+
+class DomainExpert(TextGeneration):
+    """A customized task to generate text as a domain expert in the domain of farming and agriculture."""
+
+    system_prompt: str
+    template: str = """This is the instruction: {instruction}"""
+
+    def format_input(self, input: Dict[str, Any]) -> "ChatType":
+        return [
+            {
+                "role": "system",
+                "content": self.system_prompt,
+            },
+            {
+                "role": "user",
+                "content": self.template.format(**input),
+            },
+        ]
+
+
+################################################################################
+# Main script to run the pipeline
+################################################################################
+
+
+if __name__ == "__main__":
+
+    import os
+
+    # load pipeline parameters
+
+    with open("pipeline_params.json", "r") as f:
+        params = json.load(f)
+
+    argilla_api_key = params.get("argilla_api_key")
+    argilla_api_url = params.get("argilla_api_url")
+    argilla_dataset_name = params.get("argilla_dataset_name")
+    endpoint_base_url = params.get("endpoint_base_url")
+    hub_token = os.environ.get("hub_token")
+
+    # collect our seed data
+
+    with open("seed_data.json", "r") as f:
+        seed_data = json.load(f)
+
+    topics = seed_data.get("topics", [])
+    perspectives = seed_data.get("perspectives", [])
+    domain_expert_prompt = seed_data.get("domain_expert_prompt", "")
+    examples = seed_data.get("examples", [])
+    domain_name = seed_data.get("domain_name", "domain")
+
+    # Define the task prompts
+
+    terms = create_seed_terms(topics=topics, perspectives=perspectives)
+    application_instruction = create_application_instruction(
+        domain=domain_name, examples=examples
+    )
+
+    # Define the distilabel pipeline
+
+    with Pipeline(domain_name) as pipeline:
         load_data = LoadDataFromDicts(
             name="load_data",
             data=[{"input": term} for term in terms],
             batch_size=64,
         )
-        llm = InferenceEndpointsLLM(
-            base_url=endpoint_base_url,
-            api_key=hub_token,
-        )
+
         self_instruct = SelfInstruct(
-            name="self-instruct",
-            application_description=APPLICATION_DESCRIPTION,
+            name="self_instruct",
+            application_description=application_instruction,
             num_instructions=5,
             input_batch_size=8,
-            llm=llm,
-        )
-
-        evol_instruction_complexity = EvolInstruct(
-            name="evol_instruction_complexity",
-            llm=llm,
-            num_evolutions=2,
-            store_evolutions=True,
-            input_batch_size=8,
-            include_original_instruction=True,
-            input_mappings={"instruction": "question"},
+            llm=InferenceEndpointsLLM(
+                base_url=endpoint_base_url,
+                api_key=hub_token,
+            ),
         )
 
         expand_instructions = ExpandColumns(
-            name="expand_columns", columns={"instructions": "question"}
-        )
-        cleaner = CleanNumberedList(name="clean_numbered_list")
-        expand_evolutions = ExpandColumns(
-            name="expand_columns_evolved",
-            columns={"evolved_instructions": "evolved_questions"},
+            name="expand_columns", columns={"instructions": "instruction"}
         )
 
         domain_expert = DomainExpert(
             name="domain_expert",
-            llm=llm,
+            llm=InferenceEndpointsLLM(
+                base_url=endpoint_base_url,
+                api_key=hub_token,
+            ),
             input_batch_size=8,
-            input_mappings={"instruction": "evolved_questions"},
-            output_mappings={"generation": "domain_expert_answer"},
-        )
-
-        domain_expert._system_prompt = domain_expert_prompt
-        domain_expert._template = template
-
-        keep_columns = KeepColumns(
-            name="keep_columns",
-            columns=["model_name", "evolved_questions", "domain_expert_answer"],
+            system_prompt=domain_expert_prompt,
         )
 
         to_argilla = TextGenerationToArgilla(
@@ -98,111 +145,30 @@ def define_pipeline(
             dataset_workspace="admin",
             api_url=argilla_api_url,
             api_key=argilla_api_key,
-            input_mappings={
-                "instruction": "evolved_questions",
-                "generation": "domain_expert_answer",
-            },
         )
 
+        # Connect up the pipeline
+
         load_data.connect(self_instruct)
         self_instruct.connect(expand_instructions)
-        expand_instructions.connect(cleaner)
-        cleaner.connect(evol_instruction_complexity)
-        evol_instruction_complexity.connect(expand_evolutions)
-        expand_evolutions.connect(domain_expert)
-        domain_expert.connect(keep_columns)
-        keep_columns.connect(to_argilla)
-        return pipeline
-
-
-def serialize_pipeline(
-    argilla_api_key: str,
-    argilla_api_url: str,
-    argilla_dataset_name: str,
-    topics: List[str],
-    perspectives: List[str],
-    domain_expert_prompt: str,
-    hub_token: str,
-    endpoint_base_url: str,
-    pipeline_config_path: str = "pipeline.yaml",
-    examples: List[dict] = [],
-):
-    """Serialize the pipeline to a yaml file."""
-    pipeline = define_pipeline(
-        argilla_api_key=argilla_api_key,
-        argilla_api_url=argilla_api_url,
-        argilla_dataset_name=argilla_dataset_name,
-        topics=topics,
-        perspectives=perspectives,
-        domain_expert_prompt=domain_expert_prompt,
-        hub_token=hub_token,
-        endpoint_base_url=endpoint_base_url,
-        examples=examples,
-    )
-    pipeline.save(path=pipeline_config_path, overwrite=True, format="yaml")
-
-
-def create_pipelines_run_command(
-    hub_token: str,
-    argilla_api_key: str,
-    argilla_api_url: str,
-    pipeline_config_path: str = "pipeline.yaml",
-    argilla_dataset_name: str = "domain_specific_datasets",
-):
-    """Create the command to run the pipeline."""
-    command_to_run = [
-        sys.executable,
-        "-m",
-        "distilabel",
-        "pipeline",
-        "run",
-        "--config",
-        pipeline_config_path,
-        "--param",
-        f"text_generation_to_argilla.dataset_name={argilla_dataset_name}",
-        "--param",
-        f"text_generation_to_argilla.api_key={argilla_api_key}",
-        "--param",
-        f"text_generation_to_argilla.api_url={argilla_api_url}",
-        "--param",
-        f"self-instruct.llm.api_key={hub_token}",
-        "--param",
-        f"evol_instruction_complexity.llm.api_key={hub_token}",
-        "--param",
-        f"domain_expert.llm.api_key={hub_token}",
-        "--ignore-cache",
-    ]
-    return command_to_run
-
-
-def run_pipeline(
-    hub_token: str,
-    argilla_api_key: str,
-    argilla_api_url: str,
-    pipeline_config_path: str = "pipeline.yaml",
-    argilla_dataset_name: str = "domain_specific_datasets",
-):
-    """Run the pipeline and yield the output as a generator of logs."""
-
-    command_to_run = create_pipelines_run_command(
-        hub_token=hub_token,
-        pipeline_config_path=pipeline_config_path,
-        argilla_dataset_name=argilla_dataset_name,
-        argilla_api_key=argilla_api_key,
-        argilla_api_url=argilla_api_url,
-    )
-
-    # Run the script file
-    process = subprocess.Popen(
-        args=command_to_run,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        env={"HF_TOKEN": hub_token},
-    )
-
-    while process.stdout and process.stdout.readable():
-        time.sleep(0.2)
-        line = process.stdout.readline()
-        if not line:
-            break
-        yield line.decode("utf-8")
+        expand_instructions.connect(domain_expert)
+        domain_expert.connect(to_argilla)
 
+        # Run the pipeline
 
+    pipeline.run(
+        parameters={
+            "self_instruct": {
+                "llm": {"api_key": hub_token, "base_url": endpoint_base_url}
+            },
+            "domain_expert": {
+                "llm": {"api_key": hub_token, "base_url": endpoint_base_url}
+            },
+            "text_generation_to_argilla": {
+                "dataset_name": argilla_dataset_name,
+                "api_key": argilla_api_key,
+                "api_url": argilla_api_url,
+            },
+        },
+        use_cache=False,
+    )
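Two details worth noting when running this script: it appears to read the Hub token from a lowercase `hub_token` environment variable (not from the `huggingface-cli login` cache), and the custom `DomainExpert` task builds a plain chat payload. A minimal sketch of what `format_input` returns for one row, where the prompt and instruction text are hypothetical placeholders:

    # Hypothetical values; shows the ChatType payload DomainExpert builds
    # from its system_prompt and template.
    domain_expert_prompt = "You will assume the role of a farming expert."
    instruction = "How does crop rotation affect soil health?"
    messages = [
        {"role": "system", "content": domain_expert_prompt},
        {"role": "user", "content": f"This is the instruction: {instruction}"},
    ]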
utils.py CHANGED
@@ -26,8 +26,30 @@ def project_sidebar():
     )
     st.sidebar.link_button(f"📚 Dataset Repo", DATASET_URL)
     st.sidebar.link_button(f"🤖 Argilla Space", ARGILLA_URL)
-    st.sidebar.divider()
-    st.sidebar.link_button("🧑‍🌾 New Project", DIBT_PARENT_APP_URL)
+    hub_username = DATASET_REPO_ID.split("/")[0]
+    project_name = DATASET_REPO_ID.split("/")[1]
+    st.session_state["project_name"] = project_name
+    st.session_state["hub_username"] = hub_username
+    st.session_state["hub_token"] = st.sidebar.text_input(
+        "Hub Token", type="password", value=None
+    )
     st.sidebar.link_button(
         "🤗 Get your Hub Token", "https://huggingface.co/settings/tokens"
     )
+    if all(
+        (
+            st.session_state.get("project_name"),
+            st.session_state.get("hub_username"),
+            st.session_state.get("hub_token"),
+        )
+    ):
+        st.success(f"Using the dataset repo {hub_username}/{project_name} on the Hub")
+
+    st.sidebar.divider()
+
+    st.sidebar.link_button("🧑‍🌾 New Project", DIBT_PARENT_APP_URL)
+
+    if st.session_state["hub_token"] is None:
+        st.error("Please provide a Hub token to generate answers")
+        st.stop()
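Since every page calls this sidebar, downstream pages can read the repo details straight from session state once the user has entered a token. A minimal sketch of the consuming side, mirroring how `pages/2_👩🏼‍🔬 Describe Domain.py` uses it:

    import streamlit as st
    from utils import project_sidebar

    # project_sidebar() populates hub_username, project_name and hub_token
    # in session state, and calls st.stop() if no token was provided.
    project_sidebar()

    repo_id = f"{st.session_state['hub_username']}/{st.session_state['project_name']}"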