Commit dd0124d by davidberenstein1957 · Parents: 9029def 3922cde

Merge branch 'main' of https://github.com/argilla-io/synthetic-data-generator

src/synthetic_dataset_generator/app.py CHANGED
@@ -1,5 +1,5 @@
 from synthetic_dataset_generator._tabbedinterface import TabbedInterface
-from synthetic_dataset_generator.apps.eval import app as eval_app
+# from synthetic_dataset_generator.apps.eval import app as eval_app
 from synthetic_dataset_generator.apps.readme import app as readme_app
 from synthetic_dataset_generator.apps.sft import app as sft_app
 from synthetic_dataset_generator.apps.textcat import app as textcat_app
@@ -23,8 +23,8 @@ button[role="tab"][aria-selected="true"]:hover {border-color: var(--button-prima
 image = """<br><img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: clamp(50%, 400px, 100%)"/>"""
 
 demo = TabbedInterface(
-    [textcat_app, sft_app, eval_app, readme_app],
-    ["Text Classification", "Supervised Fine-Tuning", "Evaluation", "README"],
+    [textcat_app, sft_app, readme_app],
+    ["Text Classification", "Supervised Fine-Tuning", "README"],
     css=css,
     title=image,
     head="Synthetic Data Generator",
src/synthetic_dataset_generator/apps/base.py CHANGED
@@ -67,50 +67,6 @@ def push_pipeline_code_to_hub(
     progress(1.0, desc="Pipeline code uploaded")
 
 
-def push_dataset_to_hub(
-    dataframe: pd.DataFrame,
-    private: bool = True,
-    org_name: str = None,
-    repo_name: str = None,
-    oauth_token: Union[OAuthToken, None] = None,
-    progress=gr.Progress(),
-    labels: List[str] = None,
-    num_labels: int = None,
-    task: str = TEXTCAT_TASK,
-) -> pd.DataFrame:
-    progress(0.1, desc="Setting up dataset")
-    repo_id = validate_push_to_hub(org_name, repo_name)
-
-    if task == TEXTCAT_TASK:
-        if num_labels == 1:
-            dataframe["label"] = dataframe["label"].replace("", None)
-            features = Features(
-                {"text": Value("string"), "label": ClassLabel(names=labels)}
-            )
-        else:
-            features = Features(
-                {
-                    "text": Value("string"),
-                    "labels": Sequence(feature=ClassLabel(names=labels)),
-                }
-            )
-        distiset = Distiset(
-            {"default": Dataset.from_pandas(dataframe, features=features)}
-        )
-    else:
-        distiset = Distiset({"default": Dataset.from_pandas(dataframe)})
-    progress(0.2, desc="Pushing dataset to hub")
-    distiset.push_to_hub(
-        repo_id=repo_id,
-        private=private,
-        include_script=False,
-        token=oauth_token.token,
-        create_pr=False,
-    )
-    progress(1.0, desc="Dataset pushed to hub")
-    return dataframe
-
-
 def validate_push_to_hub(org_name, repo_name):
     repo_id = (
         f"{org_name}/{repo_name}"
src/synthetic_dataset_generator/apps/sft.py CHANGED
@@ -15,7 +15,7 @@ from synthetic_dataset_generator.apps.base import (
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
 )
-from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, SFT_AVAILABLE
+from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, SFT_AVAILABLE, MODEL
 from synthetic_dataset_generator.pipelines.embeddings import (
     get_embeddings,
     get_sentence_embedding_dimensions,
@@ -49,10 +49,10 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
     return dataframe
 
 
-def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
+def generate_system_prompt(dataset_description, progress=gr.Progress()):
     progress(0.0, desc="Generating system prompt")
     progress(0.3, desc="Initializing text generation")
-    generate_description = get_prompt_generator(temperature)
+    generate_description = get_prompt_generator()
    progress(0.7, desc="Generating system prompt")
    result = next(
        generate_description.process(
@@ -92,12 +92,13 @@ def generate_dataset(
     system_prompt: str,
     num_turns: int = 1,
     num_rows: int = 10,
+    temperature: float = 0.9,
     is_sample: bool = False,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
     progress(0.0, desc="(1/2) Generating instructions")
-    magpie_generator = get_magpie_generator(system_prompt, num_turns, is_sample)
-    response_generator = get_response_generator(system_prompt, num_turns, is_sample)
+    magpie_generator = get_magpie_generator(system_prompt, num_turns, temperature, is_sample)
+    response_generator = get_response_generator(system_prompt, num_turns, temperature, is_sample)
     total_steps: int = num_rows * 2
     batch_size = DEFAULT_BATCH_SIZE
 
@@ -216,6 +217,7 @@ def push_dataset(
     num_turns: int = 1,
     num_rows: int = 10,
     private: bool = False,
+    temperature: float = 0.9,
     oauth_token: Union[gr.OAuthToken, None] = None,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
@@ -223,6 +225,7 @@ def push_dataset(
         system_prompt=system_prompt,
         num_turns=num_turns,
         num_rows=num_rows,
+        temperature=temperature,
     )
     push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
     try:
@@ -439,7 +442,7 @@ with gr.Blocks() as app:
                 label="Temperature",
                 minimum=0.1,
                 maximum=1,
-                value=0.8,
+                value=0.9,
                 step=0.1,
                 interactive=True,
             )
@@ -463,6 +466,7 @@ with gr.Blocks() as app:
         system_prompt=system_prompt.value,
         num_turns=num_turns.value,
         num_rows=num_rows.value,
+        temperature=temperature.value,
     )
     pipeline_code = gr.Code(
         value=code,
@@ -472,7 +476,7 @@ with gr.Blocks() as app:
 
     load_btn.click(
         fn=generate_system_prompt,
-        inputs=[dataset_description, temperature],
+        inputs=[dataset_description],
         outputs=[system_prompt],
         show_progress=True,
     ).then(
@@ -516,6 +520,7 @@ with gr.Blocks() as app:
             num_turns,
             num_rows,
             private,
+            temperature
         ],
         outputs=[success_message],
         show_progress=True,
@@ -525,7 +530,7 @@ with gr.Blocks() as app:
         outputs=[success_message],
     ).success(
         fn=generate_pipeline_code,
-        inputs=[system_prompt, num_turns, num_rows],
+        inputs=[system_prompt, num_turns, num_rows, temperature],
         outputs=[pipeline_code],
     ).success(
         fn=show_pipeline_code_visibility,
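
Note: the net effect in the SFT app is that the temperature slider no longer feeds generate_system_prompt and instead flows into dataset generation and pipeline-code rendering. A runnable toy Gradio sketch of that rewiring, with hypothetical stand-in functions rather than the app's real ones:

```python
import gradio as gr

# Hypothetical stand-ins, for wiring only.
def generate_system_prompt(dataset_description):
    # temperature no longer reaches this step
    return f"You generate a dataset about: {dataset_description}"

def generate_dataset(system_prompt, temperature):
    # temperature now applies to data generation
    return f"rows sampled at temperature={temperature} for: {system_prompt}"

with gr.Blocks() as demo:
    dataset_description = gr.Textbox(label="Dataset description")
    temperature = gr.Slider(minimum=0.1, maximum=1, value=0.9, step=0.1, label="Temperature")
    system_prompt = gr.Textbox(label="System prompt")
    preview = gr.Textbox(label="Preview")
    load_btn = gr.Button("Generate")
    load_btn.click(
        fn=generate_system_prompt,
        inputs=[dataset_description],  # slider removed from this step
        outputs=[system_prompt],
    ).then(
        fn=generate_dataset,
        inputs=[system_prompt, temperature],  # and threaded into this one
        outputs=[preview],
    )

if __name__ == "__main__":
    demo.launch()
```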
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -45,10 +45,10 @@ def _get_dataframe():
     )
 
 
-def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
+def generate_system_prompt(dataset_description, progress=gr.Progress()):
     progress(0.0, desc="Generating text classification task")
     progress(0.3, desc="Initializing text generation")
-    generate_description = get_prompt_generator(temperature)
+    generate_description = get_prompt_generator()
     progress(0.7, desc="Generating text classification task")
     result = next(
         generate_description.process(
@@ -89,13 +89,14 @@ def generate_dataset(
     labels: List[str] = None,
     num_labels: int = 1,
     num_rows: int = 10,
+    temperature: float = 0.9,
     is_sample: bool = False,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
     progress(0.0, desc="(1/2) Generating text classification data")
     labels = get_preprocess_labels(labels)
     textcat_generator = get_textcat_generator(
-        difficulty=difficulty, clarity=clarity, is_sample=is_sample
+        difficulty=difficulty, clarity=clarity, temperature=temperature, is_sample=is_sample
     )
     labeller_generator = get_labeller_generator(
         system_prompt=f"{system_prompt} {', '.join(labels)}",
@@ -204,6 +205,7 @@ def push_dataset(
     num_rows: int = 10,
     labels: List[str] = None,
     private: bool = False,
+    temperature: float = 0.8,
     oauth_token: Union[gr.OAuthToken, None] = None,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
@@ -214,6 +216,7 @@ def push_dataset(
         num_labels=num_labels,
         labels=labels,
         num_rows=num_rows,
+        temperature=temperature,
     )
     push_dataset_to_hub(
         dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
@@ -471,6 +474,7 @@ with gr.Blocks() as app:
         labels=labels.value,
         num_labels=num_labels.value,
         num_rows=num_rows.value,
+        temperature=temperature.value,
     )
     pipeline_code = gr.Code(
         value=code,
@@ -480,7 +484,7 @@ with gr.Blocks() as app:
 
     load_btn.click(
         fn=generate_system_prompt,
-        inputs=[dataset_description, temperature],
+        inputs=[dataset_description],
         outputs=[system_prompt, labels],
         show_progress=True,
     ).then(
@@ -537,6 +541,7 @@ with gr.Blocks() as app:
             num_rows,
             labels,
             private,
+            temperature
         ],
         outputs=[success_message],
         show_progress=True,
@@ -553,6 +558,7 @@ with gr.Blocks() as app:
             labels,
             num_labels,
             num_rows,
+            temperature
         ],
         outputs=[pipeline_code],
     ).success(
src/synthetic_dataset_generator/pipelines/sft.py CHANGED
@@ -140,7 +140,7 @@ def _get_output_mappings(num_turns):
     return {"conversation": "messages"}
 
 
-def get_prompt_generator(temperature):
+def get_prompt_generator():
     prompt_generator = TextGeneration(
         llm=InferenceEndpointsLLM(
             api_key=_get_next_api_key(),
@@ -148,7 +148,7 @@ def get_prompt_generator(temperature):
             tokenizer_id=MODEL,
             base_url=BASE_URL,
             generation_kwargs={
-                "temperature": temperature,
+                "temperature": 0.8,
                 "max_new_tokens": 2048,
                 "do_sample": True,
             },
@@ -160,7 +160,7 @@ def get_prompt_generator(temperature):
     return prompt_generator
 
 
-def get_magpie_generator(system_prompt, num_turns, is_sample):
+def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
     input_mappings = _get_output_mappings(num_turns)
     output_mappings = input_mappings.copy()
     if num_turns == 1:
@@ -172,7 +172,7 @@ def get_magpie_generator(system_prompt, num_turns, is_sample):
             api_key=_get_next_api_key(),
             magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
             generation_kwargs={
-                "temperature": 0.9,
+                "temperature": temperature,
                 "do_sample": True,
                 "max_new_tokens": 256 if is_sample else 512,
                 "stop_sequences": _STOP_SEQUENCES,
@@ -192,7 +192,7 @@ def get_magpie_generator(system_prompt, num_turns, is_sample):
             api_key=_get_next_api_key(),
             magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
             generation_kwargs={
-                "temperature": 0.9,
+                "temperature": temperature,
                 "do_sample": True,
                 "max_new_tokens": 256 if is_sample else 1024,
                 "stop_sequences": _STOP_SEQUENCES,
@@ -243,7 +243,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
     return response_generator
 
 
-def generate_pipeline_code(system_prompt, num_turns, num_rows):
+def generate_pipeline_code(system_prompt, num_turns, num_rows, temperature):
     input_mappings = _get_output_mappings(num_turns)
     code = f"""
 # Requirements: `pip install distilabel[hf-inference-endpoints]`
@@ -266,7 +266,7 @@ with Pipeline(name="sft") as pipeline:
         base_url=BASE_URL,
         magpie_pre_query_template="llama3",
         generation_kwargs={{
-            "temperature": 0.9,
+            "temperature": {temperature},
             "do_sample": True,
             "max_new_tokens": 2048,
             "stop_sequences": {_STOP_SEQUENCES}
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -66,7 +66,7 @@ class TextClassificationTask(BaseModel):
     )
 
 
-def get_prompt_generator(temperature):
+def get_prompt_generator():
     prompt_generator = TextGeneration(
         llm=InferenceEndpointsLLM(
             api_key=_get_next_api_key(),
@@ -74,7 +74,7 @@ def get_prompt_generator(temperature):
             base_url=BASE_URL,
             structured_output={"format": "json", "schema": TextClassificationTask},
             generation_kwargs={
-                "temperature": temperature,
+                "temperature": 0.8,
                 "max_new_tokens": 2048,
                 "do_sample": True,
             },
@@ -86,14 +86,14 @@ def get_prompt_generator(temperature):
     return prompt_generator
 
 
-def get_textcat_generator(difficulty, clarity, is_sample):
+def get_textcat_generator(difficulty, clarity, temperature, is_sample):
     textcat_generator = GenerateTextClassificationData(
         llm=InferenceEndpointsLLM(
             model_id=MODEL,
             base_url=BASE_URL,
             api_key=_get_next_api_key(),
             generation_kwargs={
-                "temperature": 0.9,
+                "temperature": temperature,
                 "max_new_tokens": 256 if is_sample else 2048,
                 "do_sample": True,
                 "top_k": 50,
@@ -135,6 +135,7 @@ def generate_pipeline_code(
     labels: List[str] = None,
     num_labels: int = 1,
     num_rows: int = 10,
+    temperature: float = 0.9,
 ) -> str:
     labels = get_preprocess_labels(labels)
     base_code = f"""
@@ -163,7 +164,7 @@ with Pipeline(name="textcat") as pipeline:
     base_url=BASE_URL,
     api_key=os.environ["API_KEY"],
     generation_kwargs={{
-        "temperature": 0.8,
+        "temperature": {temperature},
         "max_new_tokens": 2048,
        "do_sample": True,
        "top_k": 50,