Commit 857f1ba
Author: sdiazlor (HF staff)
Parent(s): 2673ebc

add temperature for system prompt

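The app-side change follows a common Gradio pattern: a slider is added to the UI and simply appended to the event handler's inputs, so its value reaches the handler as a plain float that is then forwarded to the LLM's generation_kwargs. The following is a minimal, self-contained sketch of that pattern; the component names, the build_prompt handler, and the launch block are illustrative, not taken from the repository. Only the Slider settings and the inputs-list wiring mirror the diff below.

import gradio as gr

def build_prompt(description: str, temperature: float) -> str:
    # Stand-in for the real generator call; in the app this value ends up in
    # InferenceEndpointsLLM's generation_kwargs["temperature"].
    return f"(temperature={temperature}) system prompt for: {description}"

with gr.Blocks() as demo:
    description = gr.Textbox(label="Dataset description")
    with gr.Accordion("Temperature", open=False):
        # Same range and default as the commit: 0.1-1.0, default 0.8.
        temperature = gr.Slider(minimum=0.1, maximum=1, value=0.8, step=0.1, show_label=False)
    out = gr.Textbox(label="System prompt")
    btn = gr.Button("Create dataset")
    # The slider is appended to the handler's inputs, mirroring the diff below.
    btn.click(fn=build_prompt, inputs=[description, temperature], outputs=out)

if __name__ == "__main__":
    demo.launch()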
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -53,11 +53,11 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
     return dataframe


-def generate_system_prompt(dataset_description, progress=gr.Progress()):
+def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
     progress(0.0, desc="Generating system prompt")

     progress(0.3, desc="Initializing text generation")
-    generate_description = get_prompt_generator()
+    generate_description = get_prompt_generator(temperature)
     progress(0.7, desc="Generating system prompt")
     result = next(
         generate_description.process(
@@ -360,6 +360,15 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
             label="Dataset description",
             placeholder="Give a precise description of your desired dataset.",
         )
+        with gr.Accordion("Temperature", open=False):
+            temperature = gr.Slider(
+                minimum=0.1,
+                maximum=1,
+                value=0.8,
+                step=0.1,
+                interactive=True,
+                show_label=False,
+            )
         load_btn = gr.Button(
             "Create dataset",
             variant="primary",
@@ -444,7 +453,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
     gr.on(
         triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
         fn=generate_system_prompt,
-        inputs=[dataset_description],
+        inputs=[dataset_description, temperature],
         outputs=[system_prompt, dataframe],
         show_progress=True,
     ).then(
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -39,10 +39,10 @@ from src.distilabel_dataset_generator.utils import (
 )


-def generate_system_prompt(dataset_description, progress=gr.Progress()):
+def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
     progress(0.0, desc="Generating text classification task")
     progress(0.3, desc="Initializing text generation")
-    generate_description = get_prompt_generator()
+    generate_description = get_prompt_generator(temperature)
     progress(0.7, desc="Generating text classification task")
     system_prompt = next(
         generate_description.process(
@@ -368,6 +368,15 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
             label="Dataset description",
             placeholder="Give a precise description of your desired dataset.",
         )
+        with gr.Accordion("Temperature", open=False):
+            temperature = gr.Slider(
+                minimum=0.1,
+                maximum=1,
+                value=0.8,
+                step=0.1,
+                interactive=True,
+                show_label=False,
+            )
         load_btn = gr.Button(
             "Create dataset",
             variant="primary",
@@ -490,7 +499,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
     gr.on(
         triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
         fn=generate_system_prompt,
-        inputs=[dataset_description],
+        inputs=[dataset_description, temperature],
         outputs=[system_prompt, dataframe],
         show_progress=True,
     ).then(
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -183,6 +183,24 @@ if __name__ == "__main__":
     return code


+def get_prompt_generator(temperature):
+    prompt_generator = TextGeneration(
+        llm=InferenceEndpointsLLM(
+            api_key=_get_next_api_key(),
+            model_id=MODEL,
+            tokenizer_id=MODEL,
+            generation_kwargs={
+                "temperature": temperature,
+                "max_new_tokens": 2048,
+                "do_sample": True,
+            },
+        ),
+        use_system_prompt=True,
+    )
+    prompt_generator.load()
+    return prompt_generator
+
+
 def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
     input_mappings = _get_output_mappings(num_turns)
     output_mappings = input_mappings.copy()
@@ -260,21 +278,3 @@ def get_response_generator(num_turns, system_prompt, is_sample):
     )
     response_generator.load()
     return response_generator
-
-
-def get_prompt_generator():
-    prompt_generator = TextGeneration(
-        llm=InferenceEndpointsLLM(
-            api_key=_get_next_api_key(),
-            model_id=MODEL,
-            tokenizer_id=MODEL,
-            generation_kwargs={
-                "temperature": 0.8,
-                "max_new_tokens": 2048,
-                "do_sample": True,
-            },
-        ),
-        use_system_prompt=True,
-    )
-    prompt_generator.load()
-    return prompt_generator
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -148,6 +148,24 @@ with Pipeline(name="textcat") as pipeline:
     )


+def get_prompt_generator(temperature):
+    prompt_generator = TextGeneration(
+        llm=InferenceEndpointsLLM(
+            api_key=_get_next_api_key(),
+            model_id=MODEL,
+            tokenizer_id=MODEL,
+            generation_kwargs={
+                "temperature": temperature,
+                "max_new_tokens": 2048,
+                "do_sample": True,
+            },
+        ),
+        use_system_prompt=True,
+    )
+    prompt_generator.load()
+    return prompt_generator
+
+
 def get_textcat_generator(difficulty, clarity, is_sample):
     textcat_generator = GenerateTextClassificationData(
         llm=InferenceEndpointsLLM(
@@ -188,21 +206,3 @@ def get_labeller_generator(system_prompt, labels, num_labels):
     )
     labeller_generator.load()
     return labeller_generator
-
-
-def get_prompt_generator():
-    prompt_generator = TextGeneration(
-        llm=InferenceEndpointsLLM(
-            api_key=_get_next_api_key(),
-            model_id=MODEL,
-            tokenizer_id=MODEL,
-            generation_kwargs={
-                "temperature": 0.8,
-                "max_new_tokens": 2048,
-                "do_sample": True,
-            },
-        ),
-        use_system_prompt=True,
-    )
-    prompt_generator.load()
-    return prompt_generator
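For completeness, the reworked factory can also be exercised outside the Gradio apps, e.g. to compare prompt variety at different temperatures. The sketch below is hedged: the import path and the "instruction"/"generation" column names are assumptions based on distilabel's standard TextGeneration interface, and are not confirmed by this commit.

# Hedged sketch; import path and input/output format are assumptions, not taken from this commit.
from src.distilabel_dataset_generator.pipelines.sft import get_prompt_generator

# 0.8 reproduces the previously hard-coded default; lower values give more deterministic prompts.
prompt_generator = get_prompt_generator(temperature=0.3)
result = next(
    prompt_generator.process([{"instruction": "A dataset of short customer-support emails."}])
)
print(result[0]["generation"])  # assumes TextGeneration's usual "generation" output column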