update sft and use input parameters
src/distilabel_dataset_generator/apps/sft.py
CHANGED
@@ -24,7 +24,6 @@ from src.distilabel_dataset_generator.pipelines.embeddings import (
 )
 from src.distilabel_dataset_generator.pipelines.sft import (
     DEFAULT_DATASET_DESCRIPTIONS,
-    PROMPT_CREATION_PROMPT,
     generate_pipeline_code,
     get_magpie_generator,
     get_prompt_generator,
@@ -55,7 +54,6 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:

 def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
     progress(0.0, desc="Generating system prompt")
-
     progress(0.3, desc="Initializing text generation")
     generate_description = get_prompt_generator(temperature)
     progress(0.7, desc="Generating system prompt")
@@ -63,20 +61,19 @@ def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
         generate_description.process(
             [
                 {
-                    "system_prompt": PROMPT_CREATION_PROMPT,
                     "instruction": dataset_description,
                 }
             ]
         )
     )[0]["generation"]
     progress(1.0, desc="System prompt generated")
-    return result
+    return result


-def generate_sample_dataset(system_prompt, progress=gr.Progress()):
+def generate_sample_dataset(system_prompt, num_turns, progress=gr.Progress()):
     df = generate_dataset(
         system_prompt=system_prompt,
-        num_turns=
+        num_turns=num_turns,
         num_rows=10,
         progress=progress,
         is_sample=True,
@@ -92,10 +89,8 @@ def generate_dataset(
     progress=gr.Progress(),
 ) -> pd.DataFrame:
     progress(0.0, desc="(1/2) Generating instructions")
-    magpie_generator = get_magpie_generator(
-        num_turns, num_rows, system_prompt, is_sample
-    )
-    response_generator = get_response_generator(num_turns, system_prompt, is_sample)
+    magpie_generator = get_magpie_generator(system_prompt, num_turns, is_sample)
+    response_generator = get_response_generator(system_prompt, num_turns, is_sample)
     total_steps: int = num_rows * 2
     batch_size = DEFAULT_BATCH_SIZE

@@ -405,7 +400,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
             )
         with gr.Column(scale=3):
             dataframe = gr.Dataframe(
-                headers=["prompt", "completion"], wrap=True, height=
+                headers=["prompt", "completion"], wrap=True, height=500, interactive=False
             )

     gr.HTML(value="<hr>")
@@ -450,15 +445,21 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
             label="Distilabel Pipeline Code",
         )

-    gr.on(
-        triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
+    load_btn.click(
         fn=generate_system_prompt,
         inputs=[dataset_description, temperature],
-        outputs=[system_prompt
+        outputs=[system_prompt],
         show_progress=True,
     ).then(
         fn=generate_sample_dataset,
-        inputs=[system_prompt],
+        inputs=[system_prompt, num_turns],
+        outputs=[dataframe],
+        show_progress=True,
+    )
+
+    btn_apply_to_sample_dataset.click(
+        fn=generate_sample_dataset,
+        inputs=[system_prompt, num_turns],
         outputs=[dataframe],
         show_progress=True,
     )
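Splitting the former gr.on(triggers=[load_btn.click, btn_apply_to_sample_dataset.click], ...) handler into an explicit load_btn.click(...).then(...) chain plus a separate btn_apply_to_sample_dataset.click(...) means only the load button regenerates the system prompt, while both buttons now thread the num_turns input into sample generation. A minimal, self-contained sketch of that wiring pattern — the component values and function bodies below are stand-ins, not the app's real implementations:

import gradio as gr

def generate_system_prompt(dataset_description, temperature):
    # Stand-in: the real app calls get_prompt_generator(temperature).
    return f"You are an assistant for: {dataset_description} (temperature={temperature})"

def generate_sample_dataset(system_prompt, num_turns):
    # Stand-in: the real app calls generate_dataset(..., num_rows=10, is_sample=True).
    return [[f"sample prompt ({num_turns} turn(s))", f"completion for: {system_prompt}"]]

with gr.Blocks() as demo:
    dataset_description = gr.Textbox(label="Dataset description")
    temperature = gr.Slider(0.0, 1.0, value=0.8, label="Temperature")
    num_turns = gr.Number(value=1, label="Number of turns")
    system_prompt = gr.Textbox(label="System prompt")
    dataframe = gr.Dataframe(headers=["prompt", "completion"], wrap=True)
    load_btn = gr.Button("Generate system prompt")
    btn_apply_to_sample_dataset = gr.Button("Apply to sample dataset")

    # Chain: generate the prompt first, then feed it into sample generation.
    load_btn.click(
        fn=generate_system_prompt,
        inputs=[dataset_description, temperature],
        outputs=[system_prompt],
    ).then(
        fn=generate_sample_dataset,
        inputs=[system_prompt, num_turns],
        outputs=[dataframe],
    )

    # Reuse the sample step alone, without regenerating the prompt.
    btn_apply_to_sample_dataset.click(
        fn=generate_sample_dataset,
        inputs=[system_prompt, num_turns],
        outputs=[dataframe],
    )

if __name__ == "__main__":
    demo.launch()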
src/distilabel_dataset_generator/pipelines/sft.py
CHANGED
@@ -138,51 +138,6 @@ def _get_output_mappings(num_turns):
     return {"conversation": "messages"}


-def generate_pipeline_code(system_prompt, num_turns, num_rows):
-    input_mappings = _get_output_mappings(num_turns)
-    code = f"""
-# Requirements: `pip install distilabel[hf-inference-endpoints]`
-import os
-from distilabel.pipeline import Pipeline
-from distilabel.steps import KeepColumns
-from distilabel.steps.tasks import MagpieGenerator
-from distilabel.llms import InferenceEndpointsLLM
-
-MODEL = "{MODEL}"
-SYSTEM_PROMPT = "{system_prompt}"
-os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
-
-with Pipeline(name="sft") as pipeline:
-    magpie = MagpieGenerator(
-        llm=InferenceEndpointsLLM(
-            model_id=MODEL,
-            tokenizer_id=MODEL,
-            magpie_pre_query_template="llama3",
-            generation_kwargs={{
-                "temperature": 0.9,
-                "do_sample": True,
-                "max_new_tokens": 2048,
-                "stop_sequences": {_STOP_SEQUENCES}
-            }},
-            api_key=os.environ["HF_TOKEN"],
-        ),
-        n_turns={num_turns},
-        num_rows={num_rows},
-        batch_size=1,
-        system_prompt=SYSTEM_PROMPT,
-        output_mappings={input_mappings},
-    )
-    keep_columns = KeepColumns(
-        columns={list(input_mappings.values())} + ["model_name"],
-    )
-    magpie.connect(keep_columns)
-
-if __name__ == "__main__":
-    distiset = pipeline.run()
-"""
-    return code
-
-
 def get_prompt_generator(temperature):
     prompt_generator = TextGeneration(
         llm=InferenceEndpointsLLM(
@@ -195,13 +150,14 @@ def get_prompt_generator(temperature):
                 "do_sample": True,
             },
         ),
+        system_prompt=PROMPT_CREATION_PROMPT,
         use_system_prompt=True,
     )
     prompt_generator.load()
     return prompt_generator


-def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
+def get_magpie_generator(system_prompt, num_turns, is_sample):
     input_mappings = _get_output_mappings(num_turns)
     output_mappings = input_mappings.copy()
     if num_turns == 1:
@@ -246,7 +202,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample):
     return magpie_generator


-def get_response_generator(num_turns, system_prompt, is_sample):
+def get_response_generator(system_prompt, num_turns, is_sample):
     if num_turns == 1:
         response_generator = TextGeneration(
             llm=InferenceEndpointsLLM(
@@ -278,3 +234,48 @@ def get_response_generator(num_turns, system_prompt, is_sample):
     )
     response_generator.load()
     return response_generator
+
+
+def generate_pipeline_code(system_prompt, num_turns, num_rows):
+    input_mappings = _get_output_mappings(num_turns)
+    code = f"""
+# Requirements: `pip install distilabel[hf-inference-endpoints]`
+import os
+from distilabel.pipeline import Pipeline
+from distilabel.steps import KeepColumns
+from distilabel.steps.tasks import MagpieGenerator
+from distilabel.llms import InferenceEndpointsLLM
+
+MODEL = "{MODEL}"
+SYSTEM_PROMPT = "{system_prompt}"
+os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
+
+with Pipeline(name="sft") as pipeline:
+    magpie = MagpieGenerator(
+        llm=InferenceEndpointsLLM(
+            model_id=MODEL,
+            tokenizer_id=MODEL,
+            magpie_pre_query_template="llama3",
+            generation_kwargs={{
+                "temperature": 0.9,
+                "do_sample": True,
+                "max_new_tokens": 2048,
+                "stop_sequences": {_STOP_SEQUENCES}
+            }},
+            api_key=os.environ["HF_TOKEN"],
+        ),
+        n_turns={num_turns},
+        num_rows={num_rows},
+        batch_size=1,
+        system_prompt=SYSTEM_PROMPT,
+        output_mappings={input_mappings},
+    )
+    keep_columns = KeepColumns(
+        columns={list(input_mappings.values())} + ["model_name"],
+    )
+    magpie.connect(keep_columns)
+
+if __name__ == "__main__":
+    distiset = pipeline.run()
+"""
+    return code
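With the creation prompt now attached to the task itself (system_prompt=PROMPT_CREATION_PROMPT alongside use_system_prompt=True), callers only supply the user instruction. A rough sketch of the resulting calling convention, assuming a distilabel-1.x-style TextGeneration — the model id, generation kwargs, and prompt constant here are illustrative stand-ins, not the module's real values:

from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import TextGeneration

# Illustrative stand-in for the PROMPT_CREATION_PROMPT constant.
PROMPT_CREATION_PROMPT = "You write system prompts for synthetic dataset generators."

prompt_generator = TextGeneration(
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",  # stand-in for MODEL
        tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        generation_kwargs={"temperature": 0.8, "max_new_tokens": 2048, "do_sample": True},
    ),
    system_prompt=PROMPT_CREATION_PROMPT,
    use_system_prompt=True,
)
prompt_generator.load()

# Input rows no longer carry a "system_prompt" key; the task injects its own.
result = next(
    prompt_generator.process([{"instruction": "A dataset of basic cooking questions"}])
)[0]["generation"]
print(result)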
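For completeness, a small usage sketch of the relocated generate_pipeline_code helper — the description, turn count, and row count are arbitrary example values:

from src.distilabel_dataset_generator.pipelines.sft import generate_pipeline_code

# Render the standalone distilabel pipeline script for a single-turn,
# 100-row dataset, then save it so it can be run outside the app.
code = generate_pipeline_code(
    system_prompt="You are a helpful assistant for cooking questions.",
    num_turns=1,
    num_rows=100,
)
with open("sft_pipeline.py", "w") as f:
    f.write(code)
print(code)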