davidberenstein1957 (HF staff) committed
Commit ded6d1c · 1 Parent(s): 791a4a1

add system prompt rewriter for more dynamic generation SFT

src/synthetic_dataset_generator/apps/sft.py CHANGED
@@ -1,4 +1,5 @@
 import ast
+import random
 import uuid
 from typing import Dict, List, Union
 
@@ -32,6 +33,7 @@ from synthetic_dataset_generator.pipelines.sft import (
     generate_pipeline_code,
     get_magpie_generator,
     get_prompt_generator,
+    get_prompt_rewriter,
     get_response_generator,
 )
 from synthetic_dataset_generator.utils import (
@@ -103,6 +105,7 @@ def generate_dataset(
 ) -> pd.DataFrame:
     num_rows = test_max_num_rows(num_rows)
     progress(0.0, desc="(1/2) Generating instructions")
+    prompt_rewriter = get_prompt_rewriter()
     magpie_generator = get_magpie_generator(
         system_prompt, num_turns, temperature, is_sample
     )
@@ -112,6 +115,16 @@ def generate_dataset(
     total_steps: int = num_rows * 2
     batch_size = DEFAULT_BATCH_SIZE
 
+    # create prompt rewrites
+    inputs = [
+        {
+            "instruction": f"Rewrite this prompt keeping the same structure but highlighting different aspects of the original without adding anything new. Original prompt: {system_prompt} Rewritten prompt: "
+        }
+        for i in range(int(num_rows / 50))
+    ]
+    batch = list(prompt_rewriter.process(inputs=inputs))
+    prompt_rewrites = [entry["generation"] for entry in batch[0]] + [system_prompt]
+
     # create instructions
     n_processed = 0
     magpie_results = []
@@ -123,7 +136,8 @@ def generate_dataset(
         )
         remaining_rows = num_rows - n_processed
         batch_size = min(batch_size, remaining_rows)
-        inputs = [{"system_prompt": system_prompt} for _ in range(batch_size)]
+        rewritten_system_prompt = random.choice(prompt_rewrites)
+        inputs = [{"system_prompt": rewritten_system_prompt} for _ in range(batch_size)]
         batch = list(magpie_generator.process(inputs=inputs))
         magpie_results.extend(batch[0])
         n_processed += batch_size
@@ -487,7 +501,7 @@ with gr.Blocks() as app:
         with gr.Column(scale=3):
             success_message = gr.Markdown(
                 visible=True,
-                height=100,  # don't remove this otherwise progress is not visible
+                min_height=100,  # don't remove this otherwise progress is not visible
             )
         with gr.Accordion(
             "Customize your pipeline with distilabel",
@@ -543,6 +557,7 @@ with gr.Blocks() as app:
         fn=hide_pipeline_code_visibility,
         inputs=[],
         outputs=[pipeline_code_ui],
+        show_progress=True,
     ).success(
         fn=push_dataset,
         inputs=[
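The net effect in generate_dataset is easier to see outside the diff. Below is a minimal, self-contained sketch of the new rewrite-then-sample flow with the distilabel call stubbed out; the names build_prompt_pool, fake_rewrite, and REWRITE_TEMPLATE are illustrative, not part of the commit.

import random

REWRITE_TEMPLATE = (
    "Rewrite this prompt keeping the same structure but highlighting "
    "different aspects of the original without adding anything new. "
    "Original prompt: {prompt} Rewritten prompt: "
)


def fake_rewrite(instruction: str) -> str:
    # Stand-in for prompt_rewriter.process(inputs=...); the real step
    # yields one {"generation": ...} dict per input.
    return instruction.upper()


def build_prompt_pool(system_prompt: str, num_rows: int) -> list[str]:
    # One rewrite per ~50 requested rows; the original prompt is always
    # appended, so the pool is never empty.
    inputs = [
        REWRITE_TEMPLATE.format(prompt=system_prompt)
        for _ in range(int(num_rows / 50))
    ]
    return [fake_rewrite(i) for i in inputs] + [system_prompt]


pool = build_prompt_pool("You are a helpful cooking assistant.", num_rows=120)
for _ in range(3):
    # each Magpie batch draws one variant, as in the while loop above
    print(random.choice(pool))

Because the original system_prompt is always in the pool, a run with num_rows < 50 (zero rewrites) degrades gracefully to the previous single-prompt behavior.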
src/synthetic_dataset_generator/pipelines/sft.py CHANGED
@@ -175,12 +175,11 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
             generation_kwargs={
                 "temperature": temperature,
                 "do_sample": True,
-                "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
+                "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.25),
                 "stop_sequences": _STOP_SEQUENCES,
             },
         ),
         n_turns=num_turns,
-        system_prompt=system_prompt,
         output_mappings=output_mappings,
         only_instruction=True,
     )
@@ -195,19 +194,34 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
             generation_kwargs={
                 "temperature": temperature,
                 "do_sample": True,
-                "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
+                "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
                 "stop_sequences": _STOP_SEQUENCES,
             },
         ),
         end_with_user=True,
         n_turns=num_turns,
-        system_prompt=system_prompt,
         output_mappings=output_mappings,
     )
     magpie_generator.load()
     return magpie_generator
 
 
+def get_prompt_rewriter():
+    prompt_rewriter = TextGeneration(
+        llm=InferenceEndpointsLLM(
+            model_id=MODEL,
+            tokenizer_id=MODEL,
+            base_url=BASE_URL,
+            api_key=_get_next_api_key(),
+            generation_kwargs={
+                "temperature": 1,
+            },
+        ),
+    )
+    prompt_rewriter.load()
+    return prompt_rewriter
+
+
 def get_response_generator(system_prompt, num_turns, temperature, is_sample):
     if num_turns == 1:
         response_generator = TextGeneration(
@@ -218,7 +232,7 @@ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
             api_key=_get_next_api_key(),
             generation_kwargs={
                 "temperature": temperature,
-                "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
+                "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
             },
         ),
         system_prompt=system_prompt,
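The three max_new_tokens changes split the per-example token budget instead of giving every stage the full MAX_NUM_TOKENS: single-turn instruction generation gets a quarter, multi-turn generation and response generation each get half. A worked example, assuming MAX_NUM_TOKENS = 2048 (an assumed value; the repo's actual default may differ):

MAX_NUM_TOKENS = 2048  # assumed value, for illustration only
is_sample = False

# single-turn instruction generation: a quarter of the budget
instruction_budget = 256 if is_sample else int(MAX_NUM_TOKENS * 0.25)  # 512
# multi-turn generation and response generation: half each
multi_turn_budget = 256 if is_sample else int(MAX_NUM_TOKENS * 0.5)    # 1024
response_budget = 256 if is_sample else int(MAX_NUM_TOKENS * 0.5)      # 1024

print(instruction_budget, multi_turn_budget, response_budget)  # 512 1024 1024

Two related design choices are visible in the diff: the new get_prompt_rewriter step sets temperature to 1 and no max_new_tokens cap, trading determinism for diversity in the rewrites, and the system_prompt kwarg is dropped from the Magpie generator because apps/sft.py now supplies a (possibly rewritten) system prompt per batch through the inputs.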