sdiazlor committed
Commit: dea1102
1 Parent(s): 4e19310

update sft and use input parameters

src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -24,7 +24,6 @@ from src.distilabel_dataset_generator.pipelines.embeddings import (
 )
 from src.distilabel_dataset_generator.pipelines.sft import (
     DEFAULT_DATASET_DESCRIPTIONS,
-    PROMPT_CREATION_PROMPT,
     generate_pipeline_code,
     get_magpie_generator,
     get_prompt_generator,
@@ -55,7 +54,6 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
 
 def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
     progress(0.0, desc="Generating system prompt")
-
     progress(0.3, desc="Initializing text generation")
     generate_description = get_prompt_generator(temperature)
     progress(0.7, desc="Generating system prompt")
@@ -63,20 +61,19 @@ def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
         generate_description.process(
             [
                 {
-                    "system_prompt": PROMPT_CREATION_PROMPT,
                     "instruction": dataset_description,
                 }
             ]
         )
     )[0]["generation"]
     progress(1.0, desc="System prompt generated")
-    return result, pd.DataFrame()
+    return result
 
 
-def generate_sample_dataset(system_prompt, progress=gr.Progress()):
+def generate_sample_dataset(system_prompt, num_turns, progress=gr.Progress()):
     df = generate_dataset(
         system_prompt=system_prompt,
-        num_turns=1,
+        num_turns=num_turns,
         num_rows=10,
         progress=progress,
         is_sample=True,
@@ -92,10 +89,8 @@ def generate_dataset(
     progress=gr.Progress(),
 ) -> pd.DataFrame:
     progress(0.0, desc="(1/2) Generating instructions")
-    magpie_generator = get_magpie_generator(
-        num_turns, num_rows, system_prompt, is_sample
-    )
-    response_generator = get_response_generator(num_turns, system_prompt, is_sample)
+    magpie_generator = get_magpie_generator(system_prompt, num_turns, is_sample)
+    response_generator = get_response_generator(system_prompt, num_turns, is_sample)
     total_steps: int = num_rows * 2
     batch_size = DEFAULT_BATCH_SIZE
 
@@ -405,7 +400,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
             )
             with gr.Column(scale=3):
                 dataframe = gr.Dataframe(
-                    headers=["prompt", "completion"], wrap=True, height=300
+                    headers=["prompt", "completion"], wrap=True, height=500, interactive=False
                 )
 
     gr.HTML(value="<hr>")
@@ -450,15 +445,21 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
         label="Distilabel Pipeline Code",
     )
 
-    gr.on(
-        triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
+    load_btn.click(
        fn=generate_system_prompt,
        inputs=[dataset_description, temperature],
-        outputs=[system_prompt, dataframe],
+        outputs=[system_prompt],
        show_progress=True,
    ).then(
        fn=generate_sample_dataset,
-        inputs=[system_prompt],
+        inputs=[system_prompt, num_turns],
+        outputs=[dataframe],
+        show_progress=True,
+    )
+
+    btn_apply_to_sample_dataset.click(
+        fn=generate_sample_dataset,
+        inputs=[system_prompt, num_turns],
        outputs=[dataframe],
        show_progress=True,
    )
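
The wiring change above splits the old shared gr.on trigger into two independent listeners: load_btn regenerates the system prompt and then a sample dataset, while btn_apply_to_sample_dataset regenerates only the sample dataset from the current (possibly hand-edited) system prompt, now honoring num_turns. A minimal, self-contained Gradio sketch of that pattern, with stub functions and illustrative component defaults in place of the app's real implementations:

import gradio as gr
import pandas as pd


def generate_system_prompt(description, temperature):
    # Stub: the real app calls a distilabel TextGeneration task here.
    return f"You are an assistant for: {description}"


def generate_sample_dataset(system_prompt, num_turns):
    # Stub: the real app runs the Magpie pipeline for 10 sample rows.
    return pd.DataFrame({"prompt": ["example question"], "completion": ["example answer"]})


with gr.Blocks() as demo:
    dataset_description = gr.Textbox(label="Dataset description")
    temperature = gr.Slider(0.0, 1.0, value=0.8, label="Temperature")
    num_turns = gr.Number(value=1, label="Number of turns")
    system_prompt = gr.Textbox(label="System prompt")
    dataframe = gr.Dataframe(headers=["prompt", "completion"], interactive=False)
    load_btn = gr.Button("Create")
    btn_apply_to_sample_dataset = gr.Button("Apply to sample dataset")

    # Chained events: first the system prompt, then the sample dataset built from it.
    load_btn.click(
        fn=generate_system_prompt,
        inputs=[dataset_description, temperature],
        outputs=[system_prompt],
    ).then(
        fn=generate_sample_dataset,
        inputs=[system_prompt, num_turns],
        outputs=[dataframe],
    )

    # An edited system prompt can be re-applied without regenerating the prompt itself.
    btn_apply_to_sample_dataset.click(
        fn=generate_sample_dataset,
        inputs=[system_prompt, num_turns],
        outputs=[dataframe],
    )

if __name__ == "__main__":
    demo.launch()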
 
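generate_dataset keeps its two-phase progress accounting (total_steps = num_rows * 2: one step per instruction, one per response). A hedged sketch of that reporting pattern as a Gradio event handler, with the distilabel generators stubbed out and the second phase description assumed by analogy with "(1/2) Generating instructions":

import gradio as gr


def generate_dataset_demo(num_rows: int, progress=gr.Progress()):
    # Each of the two phases contributes num_rows steps to the bar.
    total_steps = num_rows * 2
    progress(0.0, desc="(1/2) Generating instructions")
    for step in range(num_rows):  # stands in for the Magpie instruction generator
        progress((step + 1) / total_steps, desc="(1/2) Generating instructions")
    for step in range(num_rows):  # stands in for the response generator
        progress((num_rows + step + 1) / total_steps, desc="(2/2) Generating responses")
    return f"{num_rows} rows generated"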
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -138,51 +138,6 @@ def _get_output_mappings(num_turns):
     return {"conversation": "messages"}
 
 
-def generate_pipeline_code(system_prompt, num_turns, num_rows):
-    input_mappings = _get_output_mappings(num_turns)
-    code = f"""
-# Requirements: `pip install distilabel[hf-inference-endpoints]`
-import os
-from distilabel.pipeline import Pipeline
-from distilabel.steps import KeepColumns
-from distilabel.steps.tasks import MagpieGenerator
-from distilabel.llms import InferenceEndpointsLLM
-
-MODEL = "{MODEL}"
-SYSTEM_PROMPT = "{system_prompt}"
-os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
-
-with Pipeline(name="sft") as pipeline:
-    magpie = MagpieGenerator(
-        llm=InferenceEndpointsLLM(
-            model_id=MODEL,
-            tokenizer_id=MODEL,
-            magpie_pre_query_template="llama3",
-            generation_kwargs={{
-                "temperature": 0.9,
-                "do_sample": True,
-                "max_new_tokens": 2048,
-                "stop_sequences": {_STOP_SEQUENCES}
-            }},
-            api_key=os.environ["HF_TOKEN"],
-        ),
-        n_turns={num_turns},
-        num_rows={num_rows},
-        batch_size=1,
-        system_prompt=SYSTEM_PROMPT,
-        output_mappings={input_mappings},
-    )
-    keep_columns = KeepColumns(
-        columns={list(input_mappings.values())} + ["model_name"],
-    )
-    magpie.connect(keep_columns)
-
-if __name__ == "__main__":
-    distiset = pipeline.run()
-"""
-    return code
-
-
 def get_prompt_generator(temperature):
     prompt_generator = TextGeneration(
         llm=InferenceEndpointsLLM(
@@ -195,13 +150,14 @@ def get_prompt_generator(temperature):
                 "do_sample": True,
             },
         ),
+        system_prompt=PROMPT_CREATION_PROMPT,
         use_system_prompt=True,
     )
     prompt_generator.load()
     return prompt_generator
 
 
-def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
+def get_magpie_generator(system_prompt, num_turns, is_sample):
     input_mappings = _get_output_mappings(num_turns)
     output_mappings = input_mappings.copy()
     if num_turns == 1:
@@ -246,7 +202,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
     return magpie_generator
 
 
-def get_response_generator(num_turns, system_prompt, is_sample):
+def get_response_generator(system_prompt, num_turns, is_sample):
     if num_turns == 1:
         response_generator = TextGeneration(
             llm=InferenceEndpointsLLM(
@@ -278,3 +234,48 @@ def get_response_generator(num_turns, system_prompt, is_sample):
     )
     response_generator.load()
     return response_generator
+
+
+def generate_pipeline_code(system_prompt, num_turns, num_rows):
+    input_mappings = _get_output_mappings(num_turns)
+    code = f"""
+# Requirements: `pip install distilabel[hf-inference-endpoints]`
+import os
+from distilabel.pipeline import Pipeline
+from distilabel.steps import KeepColumns
+from distilabel.steps.tasks import MagpieGenerator
+from distilabel.llms import InferenceEndpointsLLM
+
+MODEL = "{MODEL}"
+SYSTEM_PROMPT = "{system_prompt}"
+os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
+
+with Pipeline(name="sft") as pipeline:
+    magpie = MagpieGenerator(
+        llm=InferenceEndpointsLLM(
+            model_id=MODEL,
+            tokenizer_id=MODEL,
+            magpie_pre_query_template="llama3",
+            generation_kwargs={{
+                "temperature": 0.9,
+                "do_sample": True,
+                "max_new_tokens": 2048,
+                "stop_sequences": {_STOP_SEQUENCES}
+            }},
+            api_key=os.environ["HF_TOKEN"],
+        ),
+        n_turns={num_turns},
+        num_rows={num_rows},
+        batch_size=1,
+        system_prompt=SYSTEM_PROMPT,
+        output_mappings={input_mappings},
+    )
+    keep_columns = KeepColumns(
+        columns={list(input_mappings.values())} + ["model_name"],
+    )
+    magpie.connect(keep_columns)
+
+if __name__ == "__main__":
+    distiset = pipeline.run()
+"""
+    return code
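
With this change, the meta prompt (PROMPT_CREATION_PROMPT) is fixed at task construction instead of being supplied with every input row, so callers pass only the instruction. A hedged sketch of the resulting construction; the model id, generation kwargs, and the placeholder prompt text are assumptions, not values from this repository:

from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import TextGeneration

# Placeholder for the repository's actual PROMPT_CREATION_PROMPT constant.
PROMPT_CREATION_PROMPT = "You write system prompts for synthetic dataset generation."

prompt_generator = TextGeneration(
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",  # assumed model
        tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        generation_kwargs={"temperature": 0.8, "max_new_tokens": 2048, "do_sample": True},
    ),
    system_prompt=PROMPT_CREATION_PROMPT,
    use_system_prompt=True,
)
prompt_generator.load()

# Callers now send only the instruction; the system prompt rides along implicitly.
result = next(prompt_generator.process([{"instruction": "A dataset about home cooking"}]))
print(result[0]["generation"])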
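
Because generate_pipeline_code only renders the f-string template above, it can be exercised directly; the argument values here are illustrative:

from src.distilabel_dataset_generator.pipelines.sft import generate_pipeline_code

code = generate_pipeline_code(
    system_prompt="You are a customer-support assistant.",
    num_turns=1,
    num_rows=100,
)
# Save the rendered script, set a real HF_TOKEN inside it, then run it standalone.
print(code)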