add temperature for system prompt
Browse files
src/distilabel_dataset_generator/apps/sft.py
CHANGED
@@ -53,11 +53,11 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
|
|
53 |
return dataframe
|
54 |
|
55 |
|
56 |
-
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
57 |
progress(0.0, desc="Generating system prompt")
|
58 |
|
59 |
progress(0.3, desc="Initializing text generation")
|
60 |
-
generate_description = get_prompt_generator()
|
61 |
progress(0.7, desc="Generating system prompt")
|
62 |
result = next(
|
63 |
generate_description.process(
|
@@ -360,6 +360,15 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
|
|
360 |
label="Dataset description",
|
361 |
placeholder="Give a precise description of your desired dataset.",
|
362 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
load_btn = gr.Button(
|
364 |
"Create dataset",
|
365 |
variant="primary",
|
@@ -444,7 +453,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
|
|
444 |
gr.on(
|
445 |
triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
|
446 |
fn=generate_system_prompt,
|
447 |
-
inputs=[dataset_description],
|
448 |
outputs=[system_prompt, dataframe],
|
449 |
show_progress=True,
|
450 |
).then(
|
|
|
53 |
return dataframe
|
54 |
|
55 |
|
56 |
+
def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
|
57 |
progress(0.0, desc="Generating system prompt")
|
58 |
|
59 |
progress(0.3, desc="Initializing text generation")
|
60 |
+
generate_description = get_prompt_generator(temperature)
|
61 |
progress(0.7, desc="Generating system prompt")
|
62 |
result = next(
|
63 |
generate_description.process(
|
|
|
360 |
label="Dataset description",
|
361 |
placeholder="Give a precise description of your desired dataset.",
|
362 |
)
|
363 |
+
with gr.Accordion("Temperature", open=False):
|
364 |
+
temperature = gr.Slider(
|
365 |
+
minimum=0.1,
|
366 |
+
maximum=1,
|
367 |
+
value=0.8,
|
368 |
+
step=0.1,
|
369 |
+
interactive=True,
|
370 |
+
show_label=False,
|
371 |
+
)
|
372 |
load_btn = gr.Button(
|
373 |
"Create dataset",
|
374 |
variant="primary",
|
|
|
453 |
gr.on(
|
454 |
triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
|
455 |
fn=generate_system_prompt,
|
456 |
+
inputs=[dataset_description, temperature],
|
457 |
outputs=[system_prompt, dataframe],
|
458 |
show_progress=True,
|
459 |
).then(
|
src/distilabel_dataset_generator/apps/textcat.py
CHANGED
@@ -39,10 +39,10 @@ from src.distilabel_dataset_generator.utils import (
|
|
39 |
)
|
40 |
|
41 |
|
42 |
-
def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
43 |
progress(0.0, desc="Generating text classification task")
|
44 |
progress(0.3, desc="Initializing text generation")
|
45 |
-
generate_description = get_prompt_generator()
|
46 |
progress(0.7, desc="Generating text classification task")
|
47 |
system_prompt = next(
|
48 |
generate_description.process(
|
@@ -368,6 +368,15 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
|
|
368 |
label="Dataset description",
|
369 |
placeholder="Give a precise description of your desired dataset.",
|
370 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
load_btn = gr.Button(
|
372 |
"Create dataset",
|
373 |
variant="primary",
|
@@ -490,7 +499,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
|
|
490 |
gr.on(
|
491 |
triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
|
492 |
fn=generate_system_prompt,
|
493 |
-
inputs=[dataset_description],
|
494 |
outputs=[system_prompt, dataframe],
|
495 |
show_progress=True,
|
496 |
).then(
|
|
|
39 |
)
|
40 |
|
41 |
|
42 |
+
def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
|
43 |
progress(0.0, desc="Generating text classification task")
|
44 |
progress(0.3, desc="Initializing text generation")
|
45 |
+
generate_description = get_prompt_generator(temperature)
|
46 |
progress(0.7, desc="Generating text classification task")
|
47 |
system_prompt = next(
|
48 |
generate_description.process(
|
|
|
368 |
label="Dataset description",
|
369 |
placeholder="Give a precise description of your desired dataset.",
|
370 |
)
|
371 |
+
with gr.Accordion("Temperature", open=False):
|
372 |
+
temperature = gr.Slider(
|
373 |
+
minimum=0.1,
|
374 |
+
maximum=1,
|
375 |
+
value=0.8,
|
376 |
+
step=0.1,
|
377 |
+
interactive=True,
|
378 |
+
show_label=False,
|
379 |
+
)
|
380 |
load_btn = gr.Button(
|
381 |
"Create dataset",
|
382 |
variant="primary",
|
|
|
499 |
gr.on(
|
500 |
triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
|
501 |
fn=generate_system_prompt,
|
502 |
+
inputs=[dataset_description, temperature],
|
503 |
outputs=[system_prompt, dataframe],
|
504 |
show_progress=True,
|
505 |
).then(
|
src/distilabel_dataset_generator/pipelines/sft.py
CHANGED
@@ -183,6 +183,24 @@ if __name__ == "__main__":
|
|
183 |
return code
|
184 |
|
185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
|
187 |
input_mappings = _get_output_mappings(num_turns)
|
188 |
output_mappings = input_mappings.copy()
|
@@ -260,21 +278,3 @@ def get_response_generator(num_turns, system_prompt, is_sample):
|
|
260 |
)
|
261 |
response_generator.load()
|
262 |
return response_generator
|
263 |
-
|
264 |
-
|
265 |
-
def get_prompt_generator():
|
266 |
-
prompt_generator = TextGeneration(
|
267 |
-
llm=InferenceEndpointsLLM(
|
268 |
-
api_key=_get_next_api_key(),
|
269 |
-
model_id=MODEL,
|
270 |
-
tokenizer_id=MODEL,
|
271 |
-
generation_kwargs={
|
272 |
-
"temperature": 0.8,
|
273 |
-
"max_new_tokens": 2048,
|
274 |
-
"do_sample": True,
|
275 |
-
},
|
276 |
-
),
|
277 |
-
use_system_prompt=True,
|
278 |
-
)
|
279 |
-
prompt_generator.load()
|
280 |
-
return prompt_generator
|
|
|
183 |
return code
|
184 |
|
185 |
|
186 |
+
def get_prompt_generator(temperature):
|
187 |
+
prompt_generator = TextGeneration(
|
188 |
+
llm=InferenceEndpointsLLM(
|
189 |
+
api_key=_get_next_api_key(),
|
190 |
+
model_id=MODEL,
|
191 |
+
tokenizer_id=MODEL,
|
192 |
+
generation_kwargs={
|
193 |
+
"temperature": temperature,
|
194 |
+
"max_new_tokens": 2048,
|
195 |
+
"do_sample": True,
|
196 |
+
},
|
197 |
+
),
|
198 |
+
use_system_prompt=True,
|
199 |
+
)
|
200 |
+
prompt_generator.load()
|
201 |
+
return prompt_generator
|
202 |
+
|
203 |
+
|
204 |
def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
|
205 |
input_mappings = _get_output_mappings(num_turns)
|
206 |
output_mappings = input_mappings.copy()
|
|
|
278 |
)
|
279 |
response_generator.load()
|
280 |
return response_generator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/distilabel_dataset_generator/pipelines/textcat.py
CHANGED
@@ -148,6 +148,24 @@ with Pipeline(name="textcat") as pipeline:
|
|
148 |
)
|
149 |
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
def get_textcat_generator(difficulty, clarity, is_sample):
|
152 |
textcat_generator = GenerateTextClassificationData(
|
153 |
llm=InferenceEndpointsLLM(
|
@@ -188,21 +206,3 @@ def get_labeller_generator(system_prompt, labels, num_labels):
|
|
188 |
)
|
189 |
labeller_generator.load()
|
190 |
return labeller_generator
|
191 |
-
|
192 |
-
|
193 |
-
def get_prompt_generator():
|
194 |
-
prompt_generator = TextGeneration(
|
195 |
-
llm=InferenceEndpointsLLM(
|
196 |
-
api_key=_get_next_api_key(),
|
197 |
-
model_id=MODEL,
|
198 |
-
tokenizer_id=MODEL,
|
199 |
-
generation_kwargs={
|
200 |
-
"temperature": 0.8,
|
201 |
-
"max_new_tokens": 2048,
|
202 |
-
"do_sample": True,
|
203 |
-
},
|
204 |
-
),
|
205 |
-
use_system_prompt=True,
|
206 |
-
)
|
207 |
-
prompt_generator.load()
|
208 |
-
return prompt_generator
|
|
|
148 |
)
|
149 |
|
150 |
|
151 |
+
def get_prompt_generator(temperature):
|
152 |
+
prompt_generator = TextGeneration(
|
153 |
+
llm=InferenceEndpointsLLM(
|
154 |
+
api_key=_get_next_api_key(),
|
155 |
+
model_id=MODEL,
|
156 |
+
tokenizer_id=MODEL,
|
157 |
+
generation_kwargs={
|
158 |
+
"temperature": temperature,
|
159 |
+
"max_new_tokens": 2048,
|
160 |
+
"do_sample": True,
|
161 |
+
},
|
162 |
+
),
|
163 |
+
use_system_prompt=True,
|
164 |
+
)
|
165 |
+
prompt_generator.load()
|
166 |
+
return prompt_generator
|
167 |
+
|
168 |
+
|
169 |
def get_textcat_generator(difficulty, clarity, is_sample):
|
170 |
textcat_generator = GenerateTextClassificationData(
|
171 |
llm=InferenceEndpointsLLM(
|
|
|
206 |
)
|
207 |
labeller_generator.load()
|
208 |
return labeller_generator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|