davidberenstein1957 HF staff committed on
Commit
cd47483
·
1 Parent(s): 0202688

add support for custom BASE_URL, MODEL, API_KEY

Browse files
README.md CHANGED
@@ -80,7 +80,13 @@ pip install synthetic-dataset-generator
80
 
81
  ### Environment Variables
82
 
83
- - `HF_TOKEN`: Your Hugging Face token to push your datasets to the Hugging Face Hub and run *Free* Inference Endpoints Requests. You can get one [here](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained).
 
 
 
 
 
 
84
 
85
  Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
86
 
 
80
 
81
  ### Environment Variables
82
 
83
+ - `HF_TOKEN`: Your [Hugging Face token](https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&tokenType=fineGrained) to push your datasets to the Hugging Face Hub and generate free completions from Hugging Face Inference Endpoints.
84
+
85
+ Optionally, you can set the following environment variables to customize the generation process.
86
+
87
+ - `BASE_URL`: The base URL for any OpenAI-compatible API, e.g. `https://api-inference.huggingface.co/v1/`.
88
+ - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
89
+ - `API_KEY`: The API key to use for the corresponding API, e.g. `hf_...`.
90
 
91
  Optionally, you can also push your datasets to Argilla for further curation by setting the following environment variables:
92
 
app.py CHANGED
@@ -1,8 +1,8 @@
1
- from src.distilabel_dataset_generator._tabbedinterface import TabbedInterface
2
- from src.distilabel_dataset_generator.apps.eval import app as eval_app
3
- from src.distilabel_dataset_generator.apps.faq import app as faq_app
4
- from src.distilabel_dataset_generator.apps.sft import app as sft_app
5
- from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
6
 
7
  theme = "argilla/argilla-theme"
8
 
 
1
+ from distilabel_dataset_generator._tabbedinterface import TabbedInterface
2
+ from distilabel_dataset_generator.apps.eval import app as eval_app
3
+ from distilabel_dataset_generator.apps.faq import app as faq_app
4
+ from distilabel_dataset_generator.apps.sft import app as sft_app
5
+ from distilabel_dataset_generator.apps.textcat import app as textcat_app
6
 
7
  theme = "argilla/argilla-theme"
8
 
pyproject.toml CHANGED
@@ -5,6 +5,18 @@ description = "Build datasets using natural language"
5
  authors = [
6
  {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"},
7
  ]
 
 
 
 
 
 
 
 
 
 
 
 
8
  dependencies = [
9
  "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1",
10
  "gradio[oauth]<5.0.0",
@@ -14,14 +26,10 @@ dependencies = [
14
  "gradio-huggingfacehub-search>=0.0.7",
15
  "argilla>=2.4.0",
16
  ]
17
- requires-python = "<3.13,>=3.10"
18
- readme = "README.md"
19
- license = {text = "apache 2"}
20
 
21
  [build-system]
22
  requires = ["pdm-backend"]
23
  build-backend = "pdm.backend"
24
 
25
-
26
  [tool.pdm]
27
  distribution = true
 
5
  authors = [
6
  {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"},
7
  ]
8
+ tags = [
9
+ "gradio",
10
+ "synthetic-data",
11
+ "huggingface",
12
+ "argilla",
13
+ "generative-ai",
14
+ "ai",
15
+ ]
16
+ requires-python = "<3.13,>=3.10"
17
+ readme = "README.md"
18
+ license = {text = "Apache 2"}
19
+
20
  dependencies = [
21
  "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1",
22
  "gradio[oauth]<5.0.0",
 
26
  "gradio-huggingfacehub-search>=0.0.7",
27
  "argilla>=2.4.0",
28
  ]
 
 
 
29
 
30
  [build-system]
31
  requires = ["pdm-backend"]
32
  build-backend = "pdm.backend"
33
 
 
34
  [tool.pdm]
35
  distribution = true
src/distilabel_dataset_generator/__init__.py CHANGED
@@ -1,8 +1,5 @@
1
- import os
2
- import warnings
3
  from typing import Optional
4
 
5
- import argilla as rg
6
  import distilabel
7
  import distilabel.distiset
8
  from distilabel.utils.card.dataset_card import (
@@ -11,29 +8,6 @@ from distilabel.utils.card.dataset_card import (
11
  )
12
  from huggingface_hub import DatasetCardData, HfApi
13
 
14
- HF_TOKENS = [os.getenv("HF_TOKEN")] + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
15
- HF_TOKENS = [token for token in HF_TOKENS if token]
16
-
17
- if len(HF_TOKENS) == 0:
18
- raise ValueError(
19
- "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
20
- )
21
-
22
- ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
23
- ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY")
24
- if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
25
- ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
26
- ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
27
-
28
- if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
29
- warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set")
30
- argilla_client = None
31
- else:
32
- argilla_client = rg.Argilla(
33
- api_url=ARGILLA_API_URL,
34
- api_key=ARGILLA_API_KEY,
35
- )
36
-
37
 
38
  class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
39
  def _generate_card(
 
 
 
1
  from typing import Optional
2
 
 
3
  import distilabel
4
  import distilabel.distiset
5
  from distilabel.utils.card.dataset_card import (
 
8
  )
9
  from huggingface_hub import DatasetCardData, HfApi
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  class CustomDistisetWithAdditionalTag(distilabel.distiset.Distiset):
13
  def _generate_card(
src/distilabel_dataset_generator/apps/__init__.py ADDED
File without changes
src/distilabel_dataset_generator/apps/base.py CHANGED
@@ -10,7 +10,7 @@ from distilabel.distiset import Distiset
10
  from gradio import OAuthToken
11
  from huggingface_hub import HfApi, upload_file
12
 
13
- from src.distilabel_dataset_generator.utils import (
14
  _LOGGED_OUT_CSS,
15
  get_argilla_client,
16
  get_login_button,
 
10
  from gradio import OAuthToken
11
  from huggingface_hub import HfApi, upload_file
12
 
13
+ from distilabel_dataset_generator.utils import (
14
  _LOGGED_OUT_CSS,
15
  get_argilla_client,
16
  get_login_button,
src/distilabel_dataset_generator/apps/eval.py CHANGED
@@ -16,25 +16,23 @@ from distilabel.distiset import Distiset
16
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
17
  from huggingface_hub import HfApi
18
 
19
- from src.distilabel_dataset_generator.apps.base import (
20
  hide_success_message,
21
  show_success_message,
22
  validate_argilla_user_workspace_dataset,
23
  validate_push_to_hub,
24
  )
25
- from src.distilabel_dataset_generator.pipelines.base import (
26
- DEFAULT_BATCH_SIZE,
27
- )
28
- from src.distilabel_dataset_generator.pipelines.embeddings import (
29
  get_embeddings,
30
  get_sentence_embedding_dimensions,
31
  )
32
- from src.distilabel_dataset_generator.pipelines.eval import (
33
  generate_pipeline_code,
34
  get_custom_evaluator,
35
  get_ultrafeedback_evaluator,
36
  )
37
- from src.distilabel_dataset_generator.utils import (
38
  column_to_list,
39
  extract_column_names,
40
  get_argilla_client,
 
16
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
17
  from huggingface_hub import HfApi
18
 
19
+ from distilabel_dataset_generator.apps.base import (
20
  hide_success_message,
21
  show_success_message,
22
  validate_argilla_user_workspace_dataset,
23
  validate_push_to_hub,
24
  )
25
+ from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE
26
+ from distilabel_dataset_generator.pipelines.embeddings import (
 
 
27
  get_embeddings,
28
  get_sentence_embedding_dimensions,
29
  )
30
+ from distilabel_dataset_generator.pipelines.eval import (
31
  generate_pipeline_code,
32
  get_custom_evaluator,
33
  get_ultrafeedback_evaluator,
34
  )
35
+ from distilabel_dataset_generator.utils import (
36
  column_to_list,
37
  extract_column_names,
38
  get_argilla_client,
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -9,27 +9,25 @@ from datasets import Dataset
9
  from distilabel.distiset import Distiset
10
  from huggingface_hub import HfApi
11
 
12
- from src.distilabel_dataset_generator.apps.base import (
13
  hide_success_message,
14
  show_success_message,
15
  validate_argilla_user_workspace_dataset,
16
  validate_push_to_hub,
17
  )
18
- from src.distilabel_dataset_generator.pipelines.base import (
19
- DEFAULT_BATCH_SIZE,
20
- )
21
- from src.distilabel_dataset_generator.pipelines.embeddings import (
22
  get_embeddings,
23
  get_sentence_embedding_dimensions,
24
  )
25
- from src.distilabel_dataset_generator.pipelines.sft import (
26
  DEFAULT_DATASET_DESCRIPTIONS,
27
  generate_pipeline_code,
28
  get_magpie_generator,
29
  get_prompt_generator,
30
  get_response_generator,
31
  )
32
- from src.distilabel_dataset_generator.utils import (
33
  _LOGGED_OUT_CSS,
34
  get_argilla_client,
35
  get_org_dropdown,
@@ -354,168 +352,175 @@ def hide_pipeline_code_visibility():
354
 
355
  with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
356
  with gr.Column() as main_ui:
357
- gr.Markdown(value="## 1. Describe the dataset you want")
358
- with gr.Row():
359
- with gr.Column(scale=2):
360
- dataset_description = gr.Textbox(
361
- label="Dataset description",
362
- placeholder="Give a precise description of your desired dataset.",
363
- )
364
- with gr.Accordion("Temperature", open=False):
365
- temperature = gr.Slider(
366
- minimum=0.1,
367
- maximum=1,
368
- value=0.8,
369
- step=0.1,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  interactive=True,
371
- show_label=False,
372
  )
373
- load_btn = gr.Button(
374
- "Create dataset",
375
- variant="primary",
376
- )
377
- with gr.Column(scale=2):
378
- examples = gr.Examples(
379
- examples=DEFAULT_DATASET_DESCRIPTIONS,
380
- inputs=[dataset_description],
381
- cache_examples=False,
382
- label="Examples",
383
- )
384
- with gr.Column(scale=1):
385
- pass
386
-
387
- gr.HTML(value="<hr>")
388
- gr.Markdown(value="## 2. Configure your dataset")
389
- with gr.Row(equal_height=False):
390
- with gr.Column(scale=2):
391
- system_prompt = gr.Textbox(
392
- label="System prompt",
393
- placeholder="You are a helpful assistant.",
394
- )
395
- num_turns = gr.Number(
396
- value=1,
397
- label="Number of turns in the conversation",
398
- minimum=1,
399
- maximum=4,
400
- step=1,
401
- interactive=True,
402
- info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
403
- )
404
- btn_apply_to_sample_dataset = gr.Button(
405
- "Refresh dataset", variant="secondary"
406
- )
407
- with gr.Column(scale=3):
408
- dataframe = gr.Dataframe(
409
- headers=["prompt", "completion"],
410
- wrap=True,
411
- height=500,
412
- interactive=False,
413
- )
414
-
415
- gr.HTML(value="<hr>")
416
- gr.Markdown(value="## 3. Generate your dataset")
417
- with gr.Row(equal_height=False):
418
- with gr.Column(scale=2):
419
- org_name = get_org_dropdown()
420
- repo_name = gr.Textbox(
421
- label="Repo name",
422
- placeholder="dataset_name",
423
- value=f"my-distiset-{str(uuid.uuid4())[:8]}",
424
- interactive=True,
425
- )
426
- num_rows = gr.Number(
427
- label="Number of rows",
428
- value=10,
429
- interactive=True,
430
- scale=1,
431
- )
432
- private = gr.Checkbox(
433
- label="Private dataset",
434
- value=False,
435
- interactive=True,
436
- scale=1,
437
- )
438
- btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
439
- with gr.Column(scale=3):
440
- success_message = gr.Markdown(visible=True)
441
- with gr.Accordion(
442
- "Do you want to go further? Customize and run with Distilabel",
443
- open=False,
444
- visible=False,
445
- ) as pipeline_code_ui:
446
- code = generate_pipeline_code(
447
- system_prompt=system_prompt.value,
448
- num_turns=num_turns.value,
449
- num_rows=num_rows.value,
450
  )
451
- pipeline_code = gr.Code(
452
- value=code,
453
- language="python",
454
- label="Distilabel Pipeline Code",
 
 
455
  )
456
 
457
- load_btn.click(
458
- fn=generate_system_prompt,
459
- inputs=[dataset_description, temperature],
460
- outputs=[system_prompt],
461
- show_progress=True,
462
- ).then(
463
- fn=generate_sample_dataset,
464
- inputs=[system_prompt, num_turns],
465
- outputs=[dataframe],
466
- show_progress=True,
467
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
 
469
- btn_apply_to_sample_dataset.click(
470
- fn=generate_sample_dataset,
471
- inputs=[system_prompt, num_turns],
472
- outputs=[dataframe],
473
- show_progress=True,
474
- )
475
 
476
- btn_push_to_hub.click(
477
- fn=validate_argilla_user_workspace_dataset,
478
- inputs=[repo_name],
479
- outputs=[success_message],
480
- show_progress=True,
481
- ).then(
482
- fn=validate_push_to_hub,
483
- inputs=[org_name, repo_name],
484
- outputs=[success_message],
485
- show_progress=True,
486
- ).success(
487
- fn=hide_success_message,
488
- outputs=[success_message],
489
- show_progress=True,
490
- ).success(
491
- fn=hide_pipeline_code_visibility,
492
- inputs=[],
493
- outputs=[pipeline_code_ui],
494
- ).success(
495
- fn=push_dataset,
496
- inputs=[
497
- org_name,
498
- repo_name,
499
- system_prompt,
500
- num_turns,
501
- num_rows,
502
- private,
503
- ],
504
- outputs=[success_message],
505
- show_progress=True,
506
- ).success(
507
- fn=show_success_message,
508
- inputs=[org_name, repo_name],
509
- outputs=[success_message],
510
- ).success(
511
- fn=generate_pipeline_code,
512
- inputs=[system_prompt, num_turns, num_rows],
513
- outputs=[pipeline_code],
514
- ).success(
515
- fn=show_pipeline_code_visibility,
516
- inputs=[],
517
- outputs=[pipeline_code_ui],
518
- )
519
 
520
- app.load(fn=swap_visibility, outputs=main_ui)
521
- app.load(fn=get_org_dropdown, outputs=[org_name])
 
9
  from distilabel.distiset import Distiset
10
  from huggingface_hub import HfApi
11
 
12
+ from distilabel_dataset_generator.apps.base import (
13
  hide_success_message,
14
  show_success_message,
15
  validate_argilla_user_workspace_dataset,
16
  validate_push_to_hub,
17
  )
18
+ from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE, SFT_AVAILABLE
19
+ from distilabel_dataset_generator.pipelines.embeddings import (
 
 
20
  get_embeddings,
21
  get_sentence_embedding_dimensions,
22
  )
23
+ from distilabel_dataset_generator.pipelines.sft import (
24
  DEFAULT_DATASET_DESCRIPTIONS,
25
  generate_pipeline_code,
26
  get_magpie_generator,
27
  get_prompt_generator,
28
  get_response_generator,
29
  )
30
+ from distilabel_dataset_generator.utils import (
31
  _LOGGED_OUT_CSS,
32
  get_argilla_client,
33
  get_org_dropdown,
 
352
 
353
  with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
354
  with gr.Column() as main_ui:
355
+ if not SFT_AVAILABLE:
356
+ gr.Markdown(
357
+ value=f"## Supervised Fine-Tuning is not available for the {MODEL} model. Use Hugging Face Llama3 or Qwen2 models."
358
+ )
359
+ else:
360
+ gr.Markdown(value="## 1. Describe the dataset you want")
361
+ with gr.Row():
362
+ with gr.Column(scale=2):
363
+ dataset_description = gr.Textbox(
364
+ label="Dataset description",
365
+ placeholder="Give a precise description of your desired dataset.",
366
+ )
367
+ with gr.Accordion("Temperature", open=False):
368
+ temperature = gr.Slider(
369
+ minimum=0.1,
370
+ maximum=1,
371
+ value=0.8,
372
+ step=0.1,
373
+ interactive=True,
374
+ show_label=False,
375
+ )
376
+ load_btn = gr.Button(
377
+ "Create dataset",
378
+ variant="primary",
379
+ )
380
+ with gr.Column(scale=2):
381
+ examples = gr.Examples(
382
+ examples=DEFAULT_DATASET_DESCRIPTIONS,
383
+ inputs=[dataset_description],
384
+ cache_examples=False,
385
+ label="Examples",
386
+ )
387
+ with gr.Column(scale=1):
388
+ pass
389
+
390
+ gr.HTML(value="<hr>")
391
+ gr.Markdown(value="## 2. Configure your dataset")
392
+ with gr.Row(equal_height=False):
393
+ with gr.Column(scale=2):
394
+ system_prompt = gr.Textbox(
395
+ label="System prompt",
396
+ placeholder="You are a helpful assistant.",
397
+ )
398
+ num_turns = gr.Number(
399
+ value=1,
400
+ label="Number of turns in the conversation",
401
+ minimum=1,
402
+ maximum=4,
403
+ step=1,
404
  interactive=True,
405
+ info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
406
  )
407
+ btn_apply_to_sample_dataset = gr.Button(
408
+ "Refresh dataset", variant="secondary"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  )
410
+ with gr.Column(scale=3):
411
+ dataframe = gr.Dataframe(
412
+ headers=["prompt", "completion"],
413
+ wrap=True,
414
+ height=500,
415
+ interactive=False,
416
  )
417
 
418
+ gr.HTML(value="<hr>")
419
+ gr.Markdown(value="## 3. Generate your dataset")
420
+ with gr.Row(equal_height=False):
421
+ with gr.Column(scale=2):
422
+ org_name = get_org_dropdown()
423
+ repo_name = gr.Textbox(
424
+ label="Repo name",
425
+ placeholder="dataset_name",
426
+ value=f"my-distiset-{str(uuid.uuid4())[:8]}",
427
+ interactive=True,
428
+ )
429
+ num_rows = gr.Number(
430
+ label="Number of rows",
431
+ value=10,
432
+ interactive=True,
433
+ scale=1,
434
+ )
435
+ private = gr.Checkbox(
436
+ label="Private dataset",
437
+ value=False,
438
+ interactive=True,
439
+ scale=1,
440
+ )
441
+ btn_push_to_hub = gr.Button(
442
+ "Push to Hub", variant="primary", scale=2
443
+ )
444
+ with gr.Column(scale=3):
445
+ success_message = gr.Markdown(visible=True)
446
+ with gr.Accordion(
447
+ "Do you want to go further? Customize and run with Distilabel",
448
+ open=False,
449
+ visible=False,
450
+ ) as pipeline_code_ui:
451
+ code = generate_pipeline_code(
452
+ system_prompt=system_prompt.value,
453
+ num_turns=num_turns.value,
454
+ num_rows=num_rows.value,
455
+ )
456
+ pipeline_code = gr.Code(
457
+ value=code,
458
+ language="python",
459
+ label="Distilabel Pipeline Code",
460
+ )
461
+
462
+ load_btn.click(
463
+ fn=generate_system_prompt,
464
+ inputs=[dataset_description, temperature],
465
+ outputs=[system_prompt],
466
+ show_progress=True,
467
+ ).then(
468
+ fn=generate_sample_dataset,
469
+ inputs=[system_prompt, num_turns],
470
+ outputs=[dataframe],
471
+ show_progress=True,
472
+ )
473
 
474
+ btn_apply_to_sample_dataset.click(
475
+ fn=generate_sample_dataset,
476
+ inputs=[system_prompt, num_turns],
477
+ outputs=[dataframe],
478
+ show_progress=True,
479
+ )
480
 
481
+ btn_push_to_hub.click(
482
+ fn=validate_argilla_user_workspace_dataset,
483
+ inputs=[repo_name],
484
+ outputs=[success_message],
485
+ show_progress=True,
486
+ ).then(
487
+ fn=validate_push_to_hub,
488
+ inputs=[org_name, repo_name],
489
+ outputs=[success_message],
490
+ show_progress=True,
491
+ ).success(
492
+ fn=hide_success_message,
493
+ outputs=[success_message],
494
+ show_progress=True,
495
+ ).success(
496
+ fn=hide_pipeline_code_visibility,
497
+ inputs=[],
498
+ outputs=[pipeline_code_ui],
499
+ ).success(
500
+ fn=push_dataset,
501
+ inputs=[
502
+ org_name,
503
+ repo_name,
504
+ system_prompt,
505
+ num_turns,
506
+ num_rows,
507
+ private,
508
+ ],
509
+ outputs=[success_message],
510
+ show_progress=True,
511
+ ).success(
512
+ fn=show_success_message,
513
+ inputs=[org_name, repo_name],
514
+ outputs=[success_message],
515
+ ).success(
516
+ fn=generate_pipeline_code,
517
+ inputs=[system_prompt, num_turns, num_rows],
518
+ outputs=[pipeline_code],
519
+ ).success(
520
+ fn=show_pipeline_code_visibility,
521
+ inputs=[],
522
+ outputs=[pipeline_code_ui],
523
+ )
524
 
525
+ app.load(fn=swap_visibility, outputs=main_ui)
526
+ app.load(fn=get_org_dropdown, outputs=[org_name])
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -9,15 +9,13 @@ from datasets import ClassLabel, Dataset, Features, Sequence, Value
9
  from distilabel.distiset import Distiset
10
  from huggingface_hub import HfApi
11
 
 
12
  from src.distilabel_dataset_generator.apps.base import (
13
  hide_success_message,
14
  show_success_message,
15
  validate_argilla_user_workspace_dataset,
16
  validate_push_to_hub,
17
  )
18
- from src.distilabel_dataset_generator.pipelines.base import (
19
- DEFAULT_BATCH_SIZE,
20
- )
21
  from src.distilabel_dataset_generator.pipelines.embeddings import (
22
  get_embeddings,
23
  get_sentence_embedding_dimensions,
 
9
  from distilabel.distiset import Distiset
10
  from huggingface_hub import HfApi
11
 
12
+ from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE
13
  from src.distilabel_dataset_generator.apps.base import (
14
  hide_success_message,
15
  show_success_message,
16
  validate_argilla_user_workspace_dataset,
17
  validate_push_to_hub,
18
  )
 
 
 
19
  from src.distilabel_dataset_generator.pipelines.embeddings import (
20
  get_embeddings,
21
  get_sentence_embedding_dimensions,
src/distilabel_dataset_generator/constants.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+
4
+ import argilla as rg
5
+
6
+ # Hugging Face
7
+ HF_TOKEN = os.getenv("HF_TOKEN")
8
+ if HF_TOKEN is None:
9
+ raise ValueError(
10
+ "HF_TOKEN is not set. Ensure you have set the HF_TOKEN environment variable that has access to the Hugging Face Hub repositories and Inference Endpoints."
11
+ )
12
+
13
+ # Inference
14
+ DEFAULT_BATCH_SIZE = 5
15
+ MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
16
+ API_KEYS = (
17
+ [os.getenv("HF_TOKEN")]
18
+ + [os.getenv(f"HF_TOKEN_{i}") for i in range(1, 10)]
19
+ + [os.getenv("API_KEY")]
20
+ )
21
+ API_KEYS = [token for token in API_KEYS if token]
22
+ BASE_URL = os.getenv("BASE_URL", "https://api-inference.huggingface.co/v1/")
23
+
24
+ if BASE_URL != "https://api-inference.huggingface.co/v1/" and len(API_KEYS) == 0:
25
+ raise ValueError(
26
+ "API_KEY is not set. Ensure you have set the API_KEY environment variable that has access to the Hugging Face Inference Endpoints."
27
+ )
28
+ if "Qwen2" not in MODEL and "Llama-3" not in MODEL:
29
+ SFT_AVAILABLE = False
30
+ warnings.warn(
31
+ "SFT_AVAILABLE is set to False because the model is not a Qwen or Llama model."
32
+ )
33
+ MAGPIE_PRE_QUERY_TEMPLATE = None
34
+ else:
35
+ SFT_AVAILABLE = True
36
+ if "Qwen2" in MODEL:
37
+ MAGPIE_PRE_QUERY_TEMPLATE = "qwen2"
38
+ else:
39
+ MAGPIE_PRE_QUERY_TEMPLATE = "llama3"
40
+
41
+ # Argilla
42
+ ARGILLA_API_URL = os.getenv("ARGILLA_API_URL")
43
+ ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY")
44
+ if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
45
+ ARGILLA_API_URL = os.getenv("ARGILLA_API_URL_SDG_REVIEWER")
46
+ ARGILLA_API_KEY = os.getenv("ARGILLA_API_KEY_SDG_REVIEWER")
47
+
48
+ if ARGILLA_API_URL is None or ARGILLA_API_KEY is None:
49
+ warnings.warn("ARGILLA_API_URL or ARGILLA_API_KEY is not set")
50
+ argilla_client = None
51
+ else:
52
+ argilla_client = rg.Argilla(
53
+ api_url=ARGILLA_API_URL,
54
+ api_key=ARGILLA_API_KEY,
55
+ )
src/distilabel_dataset_generator/pipelines/__init__.py ADDED
File without changes
src/distilabel_dataset_generator/pipelines/base.py CHANGED
@@ -1,12 +1,10 @@
1
- from src.distilabel_dataset_generator import HF_TOKENS
2
 
3
- DEFAULT_BATCH_SIZE = 5
4
  TOKEN_INDEX = 0
5
- MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
6
 
7
 
8
  def _get_next_api_key():
9
  global TOKEN_INDEX
10
- api_key = HF_TOKENS[TOKEN_INDEX % len(HF_TOKENS)]
11
  TOKEN_INDEX += 1
12
  return api_key
 
1
+ from distilabel_dataset_generator.constants import API_KEYS
2
 
 
3
  TOKEN_INDEX = 0
 
4
 
5
 
6
  def _get_next_api_key():
7
  global TOKEN_INDEX
8
+ api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
9
  TOKEN_INDEX += 1
10
  return api_key
src/distilabel_dataset_generator/pipelines/embeddings.py CHANGED
@@ -4,7 +4,7 @@ from sentence_transformers import SentenceTransformer
4
  from sentence_transformers.models import StaticEmbedding
5
 
6
  # Initialize a StaticEmbedding module
7
- static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
8
  model = SentenceTransformer(modules=[static_embedding])
9
 
10
 
 
4
  from sentence_transformers.models import StaticEmbedding
5
 
6
  # Initialize a StaticEmbedding module
7
+ static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-8M")
8
  model = SentenceTransformer(modules=[static_embedding])
9
 
10
 
src/distilabel_dataset_generator/pipelines/eval.py CHANGED
@@ -5,18 +5,16 @@ from distilabel.steps.tasks import (
5
  UltraFeedback,
6
  )
7
 
8
- from src.distilabel_dataset_generator.pipelines.base import (
9
- MODEL,
10
- _get_next_api_key,
11
- )
12
- from src.distilabel_dataset_generator.utils import extract_column_names
13
 
14
 
15
  def get_ultrafeedback_evaluator(aspect, is_sample):
16
  ultrafeedback_evaluator = UltraFeedback(
17
  llm=InferenceEndpointsLLM(
18
  model_id=MODEL,
19
- tokenizer_id=MODEL,
20
  api_key=_get_next_api_key(),
21
  generation_kwargs={
22
  "temperature": 0,
@@ -33,7 +31,7 @@ def get_custom_evaluator(prompt_template, structured_output, columns, is_sample)
33
  custom_evaluator = TextGeneration(
34
  llm=InferenceEndpointsLLM(
35
  model_id=MODEL,
36
- tokenizer_id=MODEL,
37
  api_key=_get_next_api_key(),
38
  structured_output={"format": "json", "schema": structured_output},
39
  generation_kwargs={
@@ -62,7 +60,8 @@ from distilabel.steps.tasks import UltraFeedback
62
  from distilabel.llms import InferenceEndpointsLLM
63
 
64
  MODEL = "{MODEL}"
65
- os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
 
66
 
67
  hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]")
68
  data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
@@ -76,8 +75,8 @@ with Pipeline(name="ultrafeedback") as pipeline:
76
  ultrafeedback_evaluator = UltraFeedback(
77
  llm=InferenceEndpointsLLM(
78
  model_id=MODEL,
79
- tokenizer_id=MODEL,
80
- api_key=os.environ["HF_TOKEN"],
81
  generation_kwargs={{
82
  "temperature": 0,
83
  "max_new_tokens": 2048,
@@ -101,7 +100,8 @@ from distilabel.steps.tasks import UltraFeedback
101
  from distilabel.llms import InferenceEndpointsLLM
102
 
103
  MODEL = "{MODEL}"
104
- os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
 
105
 
106
  hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}")
107
  data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
@@ -119,8 +119,8 @@ with Pipeline(name="ultrafeedback") as pipeline:
119
  aspect=aspect,
120
  llm=InferenceEndpointsLLM(
121
  model_id=MODEL,
122
- tokenizer_id=MODEL,
123
- api_key=os.environ["HF_TOKEN"],
124
  generation_kwargs={{
125
  "temperature": 0,
126
  "max_new_tokens": 2048,
@@ -157,6 +157,7 @@ from distilabel.steps.tasks import TextGeneration
157
  from distilabel.llms import InferenceEndpointsLLM
158
 
159
  MODEL = "{MODEL}"
 
160
  CUSTOM_TEMPLATE = "{prompt_template}"
161
  os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
162
 
@@ -171,7 +172,7 @@ with Pipeline(name="custom-evaluation") as pipeline:
171
  custom_evaluator = TextGeneration(
172
  llm=InferenceEndpointsLLM(
173
  model_id=MODEL,
174
- tokenizer_id=MODEL,
175
  api_key=os.environ["HF_TOKEN"],
176
  structured_output={{"format": "json", "schema": {structured_output}}},
177
  generation_kwargs={{
 
5
  UltraFeedback,
6
  )
7
 
8
+ from distilabel_dataset_generator.constants import BASE_URL, MODEL
9
+ from distilabel_dataset_generator.pipelines.base import _get_next_api_key
10
+ from distilabel_dataset_generator.utils import extract_column_names
 
 
11
 
12
 
13
  def get_ultrafeedback_evaluator(aspect, is_sample):
14
  ultrafeedback_evaluator = UltraFeedback(
15
  llm=InferenceEndpointsLLM(
16
  model_id=MODEL,
17
+ base_url=BASE_URL,
18
  api_key=_get_next_api_key(),
19
  generation_kwargs={
20
  "temperature": 0,
 
31
  custom_evaluator = TextGeneration(
32
  llm=InferenceEndpointsLLM(
33
  model_id=MODEL,
34
+ base_url=BASE_URL,
35
  api_key=_get_next_api_key(),
36
  structured_output={"format": "json", "schema": structured_output},
37
  generation_kwargs={
 
60
  from distilabel.llms import InferenceEndpointsLLM
61
 
62
  MODEL = "{MODEL}"
63
+ BASE_URL = "{BASE_URL}"
64
+ os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
65
 
66
  hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]")
67
  data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
 
75
  ultrafeedback_evaluator = UltraFeedback(
76
  llm=InferenceEndpointsLLM(
77
  model_id=MODEL,
78
+ base_url=BASE_URL,
79
+ api_key=os.environ["API_KEY"],
80
  generation_kwargs={{
81
  "temperature": 0,
82
  "max_new_tokens": 2048,
 
100
  from distilabel.llms import InferenceEndpointsLLM
101
 
102
  MODEL = "{MODEL}"
103
+ BASE_URL = "{BASE_URL}"
104
+ os.environ["BASE_URL"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
105
 
106
  hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}")
107
  data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
 
119
  aspect=aspect,
120
  llm=InferenceEndpointsLLM(
121
  model_id=MODEL,
122
+ base_url=BASE_URL,
123
+ api_key=os.environ["BASE_URL"],
124
  generation_kwargs={{
125
  "temperature": 0,
126
  "max_new_tokens": 2048,
 
157
  from distilabel.llms import InferenceEndpointsLLM
158
 
159
  MODEL = "{MODEL}"
160
+ BASE_URL = "{BASE_URL}"
161
  CUSTOM_TEMPLATE = "{prompt_template}"
162
  os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
163
 
 
172
  custom_evaluator = TextGeneration(
173
  llm=InferenceEndpointsLLM(
174
  model_id=MODEL,
175
+ base_url=BASE_URL,
176
  api_key=os.environ["HF_TOKEN"],
177
  structured_output={{"format": "json", "schema": {structured_output}}},
178
  generation_kwargs={{
src/distilabel_dataset_generator/pipelines/sft.py CHANGED
@@ -1,10 +1,12 @@
1
  from distilabel.llms import InferenceEndpointsLLM
2
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
3
 
4
- from src.distilabel_dataset_generator.pipelines.base import (
 
 
5
  MODEL,
6
- _get_next_api_key,
7
  )
 
8
 
9
  INFORMATION_SEEKING_PROMPT = (
10
  "You are an AI assistant designed to provide accurate and concise information on a wide"
@@ -144,6 +146,7 @@ def get_prompt_generator(temperature):
144
  api_key=_get_next_api_key(),
145
  model_id=MODEL,
146
  tokenizer_id=MODEL,
 
147
  generation_kwargs={
148
  "temperature": temperature,
149
  "max_new_tokens": 2048,
@@ -165,8 +168,9 @@ def get_magpie_generator(system_prompt, num_turns, is_sample):
165
  llm=InferenceEndpointsLLM(
166
  model_id=MODEL,
167
  tokenizer_id=MODEL,
 
168
  api_key=_get_next_api_key(),
169
- magpie_pre_query_template="llama3",
170
  generation_kwargs={
171
  "temperature": 0.9,
172
  "do_sample": True,
@@ -184,8 +188,9 @@ def get_magpie_generator(system_prompt, num_turns, is_sample):
184
  llm=InferenceEndpointsLLM(
185
  model_id=MODEL,
186
  tokenizer_id=MODEL,
 
187
  api_key=_get_next_api_key(),
188
- magpie_pre_query_template="llama3",
189
  generation_kwargs={
190
  "temperature": 0.9,
191
  "do_sample": True,
@@ -208,6 +213,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
208
  llm=InferenceEndpointsLLM(
209
  model_id=MODEL,
210
  tokenizer_id=MODEL,
 
211
  api_key=_get_next_api_key(),
212
  generation_kwargs={
213
  "temperature": 0.8,
@@ -223,6 +229,7 @@ def get_response_generator(system_prompt, num_turns, is_sample):
223
  llm=InferenceEndpointsLLM(
224
  model_id=MODEL,
225
  tokenizer_id=MODEL,
 
226
  api_key=_get_next_api_key(),
227
  generation_kwargs={
228
  "temperature": 0.8,
@@ -247,14 +254,16 @@ from distilabel.steps.tasks import MagpieGenerator
247
  from distilabel.llms import InferenceEndpointsLLM
248
 
249
  MODEL = "{MODEL}"
 
250
  SYSTEM_PROMPT = "{system_prompt}"
251
- os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
252
 
253
  with Pipeline(name="sft") as pipeline:
254
  magpie = MagpieGenerator(
255
  llm=InferenceEndpointsLLM(
256
  model_id=MODEL,
257
  tokenizer_id=MODEL,
 
258
  magpie_pre_query_template="llama3",
259
  generation_kwargs={{
260
  "temperature": 0.9,
@@ -262,7 +271,7 @@ with Pipeline(name="sft") as pipeline:
262
  "max_new_tokens": 2048,
263
  "stop_sequences": {_STOP_SEQUENCES}
264
  }},
265
- api_key=os.environ["HF_TOKEN"],
266
  ),
267
  n_turns={num_turns},
268
  num_rows={num_rows},
 
1
  from distilabel.llms import InferenceEndpointsLLM
2
  from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
3
 
4
+ from distilabel_dataset_generator.constants import (
5
+ BASE_URL,
6
+ MAGPIE_PRE_QUERY_TEMPLATE,
7
  MODEL,
 
8
  )
9
+ from distilabel_dataset_generator.pipelines.base import _get_next_api_key
10
 
11
  INFORMATION_SEEKING_PROMPT = (
12
  "You are an AI assistant designed to provide accurate and concise information on a wide"
 
146
  api_key=_get_next_api_key(),
147
  model_id=MODEL,
148
  tokenizer_id=MODEL,
149
+ base_url=BASE_URL,
150
  generation_kwargs={
151
  "temperature": temperature,
152
  "max_new_tokens": 2048,
 
168
  llm=InferenceEndpointsLLM(
169
  model_id=MODEL,
170
  tokenizer_id=MODEL,
171
+ base_url=BASE_URL,
172
  api_key=_get_next_api_key(),
173
+ magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
174
  generation_kwargs={
175
  "temperature": 0.9,
176
  "do_sample": True,
 
188
  llm=InferenceEndpointsLLM(
189
  model_id=MODEL,
190
  tokenizer_id=MODEL,
191
+ base_url=BASE_URL,
192
  api_key=_get_next_api_key(),
193
+ magpie_pre_query_template=MAGPIE_PRE_QUERY_TEMPLATE,
194
  generation_kwargs={
195
  "temperature": 0.9,
196
  "do_sample": True,
 
213
  llm=InferenceEndpointsLLM(
214
  model_id=MODEL,
215
  tokenizer_id=MODEL,
216
+ base_url=BASE_URL,
217
  api_key=_get_next_api_key(),
218
  generation_kwargs={
219
  "temperature": 0.8,
 
229
  llm=InferenceEndpointsLLM(
230
  model_id=MODEL,
231
  tokenizer_id=MODEL,
232
+ base_url=BASE_URL,
233
  api_key=_get_next_api_key(),
234
  generation_kwargs={
235
  "temperature": 0.8,
 
254
  from distilabel.llms import InferenceEndpointsLLM
255
 
256
  MODEL = "{MODEL}"
257
+ BASE_URL = "{BASE_URL}"
258
  SYSTEM_PROMPT = "{system_prompt}"
259
+ os.environ["API_KEY"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
260
 
261
  with Pipeline(name="sft") as pipeline:
262
  magpie = MagpieGenerator(
263
  llm=InferenceEndpointsLLM(
264
  model_id=MODEL,
265
  tokenizer_id=MODEL,
266
+ base_url=BASE_URL,
267
  magpie_pre_query_template="llama3",
268
  generation_kwargs={{
269
  "temperature": 0.9,
 
271
  "max_new_tokens": 2048,
272
  "stop_sequences": {_STOP_SEQUENCES}
273
  }},
274
+ api_key=os.environ["API_KEY"],
275
  ),
276
  n_turns={num_turns},
277
  num_rows={num_rows},
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -1,5 +1,4 @@
1
  import random
2
- from pydantic import BaseModel, Field
3
  from typing import List
4
 
5
  from distilabel.llms import InferenceEndpointsLLM
@@ -8,12 +7,11 @@ from distilabel.steps.tasks import (
8
  TextClassification,
9
  TextGeneration,
10
  )
 
11
 
12
- from src.distilabel_dataset_generator.pipelines.base import (
13
- MODEL,
14
- _get_next_api_key,
15
- )
16
- from src.distilabel_dataset_generator.utils import get_preprocess_labels
17
 
18
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
19
 
@@ -73,7 +71,7 @@ def get_prompt_generator(temperature):
73
  llm=InferenceEndpointsLLM(
74
  api_key=_get_next_api_key(),
75
  model_id=MODEL,
76
- tokenizer_id=MODEL,
77
  structured_output={"format": "json", "schema": TextClassificationTask},
78
  generation_kwargs={
79
  "temperature": temperature,
@@ -92,7 +90,7 @@ def get_textcat_generator(difficulty, clarity, is_sample):
92
  textcat_generator = GenerateTextClassificationData(
93
  llm=InferenceEndpointsLLM(
94
  model_id=MODEL,
95
- tokenizer_id=MODEL,
96
  api_key=_get_next_api_key(),
97
  generation_kwargs={
98
  "temperature": 0.9,
@@ -114,7 +112,7 @@ def get_labeller_generator(system_prompt, labels, num_labels):
114
  labeller_generator = TextClassification(
115
  llm=InferenceEndpointsLLM(
116
  model_id=MODEL,
117
- tokenizer_id=MODEL,
118
  api_key=_get_next_api_key(),
119
  generation_kwargs={
120
  "temperature": 0.7,
@@ -149,8 +147,9 @@ from distilabel.steps import LoadDataFromDicts, KeepColumns
149
  from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}
150
 
151
  MODEL = "{MODEL}"
 
152
  TEXT_CLASSIFICATION_TASK = "{system_prompt}"
153
- os.environ["HF_TOKEN"] = (
154
  "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
155
  )
156
 
@@ -161,8 +160,8 @@ with Pipeline(name="textcat") as pipeline:
161
  textcat_generation = GenerateTextClassificationData(
162
  llm=InferenceEndpointsLLM(
163
  model_id=MODEL,
164
- tokenizer_id=MODEL,
165
- api_key=os.environ["HF_TOKEN"],
166
  generation_kwargs={{
167
  "temperature": 0.8,
168
  "max_new_tokens": 2048,
@@ -205,8 +204,8 @@ with Pipeline(name="textcat") as pipeline:
205
  textcat_labeller = TextClassification(
206
  llm=InferenceEndpointsLLM(
207
  model_id=MODEL,
208
- tokenizer_id=MODEL,
209
- api_key=os.environ["HF_TOKEN"],
210
  generation_kwargs={{
211
  "temperature": 0.8,
212
  "max_new_tokens": 2048,
 
1
  import random
 
2
  from typing import List
3
 
4
  from distilabel.llms import InferenceEndpointsLLM
 
7
  TextClassification,
8
  TextGeneration,
9
  )
10
+ from pydantic import BaseModel, Field
11
 
12
+ from distilabel_dataset_generator.constants import BASE_URL, MODEL
13
+ from distilabel_dataset_generator.pipelines.base import _get_next_api_key
14
+ from distilabel_dataset_generator.utils import get_preprocess_labels
 
 
15
 
16
  PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating very precise text classification tasks for dataset creation.
17
 
 
71
  llm=InferenceEndpointsLLM(
72
  api_key=_get_next_api_key(),
73
  model_id=MODEL,
74
+ base_url=BASE_URL,
75
  structured_output={"format": "json", "schema": TextClassificationTask},
76
  generation_kwargs={
77
  "temperature": temperature,
 
90
  textcat_generator = GenerateTextClassificationData(
91
  llm=InferenceEndpointsLLM(
92
  model_id=MODEL,
93
+ base_url=BASE_URL,
94
  api_key=_get_next_api_key(),
95
  generation_kwargs={
96
  "temperature": 0.9,
 
112
  labeller_generator = TextClassification(
113
  llm=InferenceEndpointsLLM(
114
  model_id=MODEL,
115
+ base_url=BASE_URL,
116
  api_key=_get_next_api_key(),
117
  generation_kwargs={
118
  "temperature": 0.7,
 
147
  from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}
148
 
149
  MODEL = "{MODEL}"
150
+ BASE_URL = "{BASE_URL}"
151
  TEXT_CLASSIFICATION_TASK = "{system_prompt}"
152
+ os.environ["API_KEY"] = (
153
  "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
154
  )
155
 
 
160
  textcat_generation = GenerateTextClassificationData(
161
  llm=InferenceEndpointsLLM(
162
  model_id=MODEL,
163
+ base_url=BASE_URL,
164
+ api_key=os.environ["API_KEY"],
165
  generation_kwargs={{
166
  "temperature": 0.8,
167
  "max_new_tokens": 2048,
 
204
  textcat_labeller = TextClassification(
205
  llm=InferenceEndpointsLLM(
206
  model_id=MODEL,
207
+ base_url=BASE_URL,
208
+ api_key=os.environ["API_KEY"],
209
  generation_kwargs={{
210
  "temperature": 0.8,
211
  "max_new_tokens": 2048,
src/distilabel_dataset_generator/utils.py CHANGED
@@ -15,7 +15,7 @@ from gradio.oauth import (
15
  from huggingface_hub import whoami
16
  from jinja2 import Environment, meta
17
 
18
- from src.distilabel_dataset_generator import argilla_client
19
 
20
  _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
21
 
 
15
  from huggingface_hub import whoami
16
  from jinja2 import Environment, meta
17
 
18
+ from distilabel_dataset_generator.constants import argilla_client
19
 
20
  _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
21