davidberenstein1957 HF staff committed on
Commit
88a4065
·
1 Parent(s): ae92377

add support for running without argilla

Browse files
src/distilabel_dataset_generator/apps/eval.py CHANGED
@@ -39,9 +39,9 @@ from src.distilabel_dataset_generator.utils import (
39
  extract_column_names,
40
  get_argilla_client,
41
  get_org_dropdown,
 
42
  process_columns,
43
  swap_visibility,
44
- pad_or_truncate_list,
45
  )
46
 
47
 
@@ -334,8 +334,10 @@ def push_dataset(
334
  push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
335
  try:
336
  progress(0.1, desc="Setting up user and workspace")
337
- client = get_argilla_client()
338
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
 
 
 
339
  if eval_type == "ultrafeedback":
340
  num_generations = len((dataframe["generations"][0]))
341
  fields = [
@@ -580,6 +582,7 @@ def push_dataset(
580
  def show_pipeline_code_visibility():
581
  return {pipeline_code_ui: gr.Accordion(visible=True)}
582
 
 
583
  def hide_pipeline_code_visibility():
584
  return {pipeline_code_ui: gr.Accordion(visible=False)}
585
 
@@ -708,15 +711,15 @@ with gr.Blocks() as app:
708
  visible=False,
709
  ) as pipeline_code_ui:
710
  code = generate_pipeline_code(
711
- repo_id=search_in.value,
712
- aspects=aspects_instruction_response.value,
713
- instruction_column=instruction_instruction_response,
714
- response_columns=response_instruction_response,
715
- prompt_template=prompt_template.value,
716
- structured_output=structured_output.value,
717
- num_rows=num_rows.value,
718
- eval_type=eval_type.value,
719
- )
720
  pipeline_code = gr.Code(
721
  value=code,
722
  language="python",
 
39
  extract_column_names,
40
  get_argilla_client,
41
  get_org_dropdown,
42
+ pad_or_truncate_list,
43
  process_columns,
44
  swap_visibility,
 
45
  )
46
 
47
 
 
334
  push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
335
  try:
336
  progress(0.1, desc="Setting up user and workspace")
 
337
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
338
+ client = get_argilla_client()
339
+ if client is None:
340
+ return ""
341
  if eval_type == "ultrafeedback":
342
  num_generations = len((dataframe["generations"][0]))
343
  fields = [
 
582
  def show_pipeline_code_visibility():
583
  return {pipeline_code_ui: gr.Accordion(visible=True)}
584
 
585
+
586
  def hide_pipeline_code_visibility():
587
  return {pipeline_code_ui: gr.Accordion(visible=False)}
588
 
 
711
  visible=False,
712
  ) as pipeline_code_ui:
713
  code = generate_pipeline_code(
714
+ repo_id=search_in.value,
715
+ aspects=aspects_instruction_response.value,
716
+ instruction_column=instruction_instruction_response,
717
+ response_columns=response_instruction_response,
718
+ prompt_template=prompt_template.value,
719
+ structured_output=structured_output.value,
720
+ num_rows=num_rows.value,
721
+ eval_type=eval_type.value,
722
+ )
723
  pipeline_code = gr.Code(
724
  value=code,
725
  language="python",
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -220,8 +220,10 @@ def push_dataset(
220
  push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
221
  try:
222
  progress(0.1, desc="Setting up user and workspace")
223
- client = get_argilla_client()
224
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
 
 
 
225
  if "messages" in dataframe.columns:
226
  settings = rg.Settings(
227
  fields=[
 
220
  push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
221
  try:
222
  progress(0.1, desc="Setting up user and workspace")
 
223
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
224
+ client = get_argilla_client()
225
+ if client is None:
226
+ return ""
227
  if "messages" in dataframe.columns:
228
  settings = rg.Settings(
229
  fields=[
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -58,7 +58,10 @@ def generate_system_prompt(dataset_description, temperature, progress=gr.Progres
58
  labels = data["labels"]
59
  return system_prompt, labels
60
 
61
- def generate_sample_dataset(system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress()):
 
 
 
62
  dataframe = generate_dataset(
63
  system_prompt=system_prompt,
64
  difficulty=difficulty,
@@ -138,11 +141,7 @@ def generate_dataset(
138
  # create final dataset
139
  distiset_results = []
140
  for result in labeller_results:
141
- record = {
142
- key: result[key]
143
- for key in ["labels", "text"]
144
- if key in result
145
- }
146
  distiset_results.append(record)
147
 
148
  dataframe = pd.DataFrame(distiset_results)
@@ -212,13 +211,16 @@ def push_dataset(
212
  push_dataset_to_hub(
213
  dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
214
  )
 
215
  dataframe = dataframe[
216
  (dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
217
  ]
218
  try:
219
  progress(0.1, desc="Setting up user and workspace")
220
- client = get_argilla_client()
221
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
 
 
 
222
  labels = get_preprocess_labels(labels)
223
  settings = rg.Settings(
224
  fields=[
 
58
  labels = data["labels"]
59
  return system_prompt, labels
60
 
61
+
62
+ def generate_sample_dataset(
63
+ system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress()
64
+ ):
65
  dataframe = generate_dataset(
66
  system_prompt=system_prompt,
67
  difficulty=difficulty,
 
141
  # create final dataset
142
  distiset_results = []
143
  for result in labeller_results:
144
+ record = {key: result[key] for key in ["labels", "text"] if key in result}
 
 
 
 
145
  distiset_results.append(record)
146
 
147
  dataframe = pd.DataFrame(distiset_results)
 
211
  push_dataset_to_hub(
212
  dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
213
  )
214
+
215
  dataframe = dataframe[
216
  (dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
217
  ]
218
  try:
219
  progress(0.1, desc="Setting up user and workspace")
 
220
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
221
+ client = get_argilla_client()
222
+ if client is None:
223
+ return ""
224
  labels = get_preprocess_labels(labels)
225
  settings = rg.Settings(
226
  fields=[