import re
from typing import List, Union

import argilla as rg
import gradio as gr
import pandas as pd
from datasets import Dataset
from huggingface_hub import HfApi

from src.distilabel_dataset_generator.apps.base import (
    get_argilla_client,
    get_main_ui,
    get_pipeline_code_ui,
    hide_success_message,
    push_pipeline_code_to_hub,
    show_success_message_argilla,
    show_success_message_hub,
    validate_argilla_user_workspace_dataset,
)
from src.distilabel_dataset_generator.apps.base import (
    push_dataset_to_hub as push_to_hub_base,
)
from src.distilabel_dataset_generator.pipelines.base import (
    DEFAULT_BATCH_SIZE,
)
from src.distilabel_dataset_generator.pipelines.embeddings import (
    get_embeddings,
    get_sentence_embedding_dimensions,
)
from src.distilabel_dataset_generator.pipelines.textcat import (
    DEFAULT_DATASET_DESCRIPTIONS,
    DEFAULT_DATASETS,
    DEFAULT_SYSTEM_PROMPTS,
    PROMPT_CREATION_PROMPT,
    generate_pipeline_code,
    get_labeller_generator,
    get_prompt_generator,
    get_textcat_generator,
)
from src.distilabel_dataset_generator.utils import get_preprocess_labels

TASK = "text_classification"


def push_dataset_to_hub(
    dataframe: pd.DataFrame,
    private: bool = True,
    org_name: str = None,
    repo_name: str = None,
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
    labels: List[str] = None,
    num_labels: int = 1,
):
    original_dataframe = dataframe.copy(deep=True)
    labels = get_preprocess_labels(labels)
    try:
        push_to_hub_base(
            dataframe,
            private,
            org_name,
            repo_name,
            oauth_token,
            progress,
            labels,
            num_labels,
            task=TASK,
        )
    except Exception as e:
        raise gr.Error(f"Error pushing dataset to the Hub: {e}")
    return original_dataframe
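
# Usage sketch for `push_dataset_to_hub` (kept as a comment so it never runs on
# import). The org/repo names and the `token` object are placeholders, not part
# of this module; in the running app the OAuth token is injected by Gradio.
#
#   df = pd.DataFrame({"text": ["great product"], "label": ["positive"]})
#   push_dataset_to_hub(
#       df,
#       private=True,
#       org_name="my-org",                # placeholder
#       repo_name="my-textcat-dataset",   # placeholder
#       oauth_token=token,                # injected by Gradio in the app
#       labels=["positive", "negative"],
#       num_labels=1,
#   )
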
def push_dataset_to_argilla(
    dataframe: pd.DataFrame,
    dataset_name: str,
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
    num_labels: int = 1,
    labels: List[str] = None,
) -> pd.DataFrame:
    original_dataframe = dataframe.copy(deep=True)
    try:
        progress(0.1, desc="Setting up user and workspace")
        client = get_argilla_client()
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        labels = get_preprocess_labels(labels)
        settings = rg.Settings(
            fields=[
                rg.TextField(
                    name="text",
                    description="The text classification data",
                    title="Text",
                ),
            ],
            questions=[
                (
                    rg.LabelQuestion(
                        name="label",
                        title="Label",
                        description="The label of the text",
                        labels=labels,
                    )
                    if num_labels == 1
                    else rg.MultiLabelQuestion(
                        name="labels",
                        title="Labels",
                        description="The labels of the text",
                        labels=labels,
                    )
                ),
            ],
            metadata=[
                rg.IntegerMetadataProperty(name="text_length", title="Text Length"),
            ],
            vectors=[
                rg.VectorField(
                    name="text_embeddings",
                    dimensions=get_sentence_embedding_dimensions(),
                )
            ],
            guidelines="Please review the text and provide or correct the label where needed.",
        )

        dataframe["text_length"] = dataframe["text"].apply(len)
        dataframe["text_embeddings"] = get_embeddings(dataframe["text"])

        progress(0.5, desc="Creating dataset")
        rg_dataset = client.datasets(name=dataset_name, workspace=hf_user)
        if rg_dataset is None:
            rg_dataset = rg.Dataset(
                name=dataset_name,
                workspace=hf_user,
                settings=settings,
                client=client,
            )
            rg_dataset = rg_dataset.create()

        progress(0.7, desc="Pushing dataset to Argilla")
        hf_dataset = Dataset.from_pandas(dataframe)
        records = [
            rg.Record(
                fields={
                    "text": sample["text"],
                },
                metadata={"text_length": sample["text_length"]},
                vectors={"text_embeddings": sample["text_embeddings"]},
                suggestions=(
                    [
                        rg.Suggestion(
                            question_name="label" if num_labels == 1 else "labels",
                            value=(
                                sample["label"]
                                if num_labels == 1
                                else sample["labels"]
                            ),
                        )
                    ]
                    # only attach a suggestion when the generated label(s)
                    # belong to the configured taxonomy
                    if (
                        (num_labels == 1 and sample["label"] in labels)
                        or (
                            num_labels > 1
                            and all(label in labels for label in sample["labels"])
                        )
                    )
                    else []
                ),
            )
            for sample in hf_dataset
        ]
        rg_dataset.records.log(records=records)
        progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        raise gr.Error(f"Error pushing dataset to Argilla: {e}")
    return original_dataframe


def generate_system_prompt(dataset_description, progress=gr.Progress()):
    progress(0.0, desc="Generating text classification task")
    if dataset_description in DEFAULT_DATASET_DESCRIPTIONS:
        index = DEFAULT_DATASET_DESCRIPTIONS.index(dataset_description)
        if index < len(DEFAULT_SYSTEM_PROMPTS):
            return DEFAULT_SYSTEM_PROMPTS[index]

    progress(0.3, desc="Initializing text generation")
    generate_description = get_prompt_generator()
    progress(0.7, desc="Generating text classification task")
    result = next(
        generate_description.process(
            [
                {
                    "system_prompt": PROMPT_CREATION_PROMPT,
                    "instruction": dataset_description,
                }
            ]
        )
    )[0]["generation"]
    progress(1.0, desc="Text classification task generated")
    return result


def generate_dataset(
    system_prompt: str,
    difficulty: str,
    clarity: str,
    labels: List[str] = None,
    num_labels: int = 1,
    num_rows: int = 10,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    progress(0.0, desc="(1/2) Generating text classification data")
    labels = get_preprocess_labels(labels)
    textcat_generator = get_textcat_generator(
        difficulty=difficulty, clarity=clarity, is_sample=is_sample
    )
    labeller_generator = get_labeller_generator(
        system_prompt=system_prompt,
        labels=labels,
        num_labels=num_labels,
    )
    total_steps: int = num_rows * 2
    batch_size = DEFAULT_BATCH_SIZE

    # create text classification data
    n_processed = 0
    textcat_results = []
    while n_processed < num_rows:
        progress(
            0.5 * n_processed / num_rows,
            total=total_steps,
            desc="(1/2) Generating text classification data",
        )
        remaining_rows = num_rows - n_processed
        batch_size = min(batch_size, remaining_rows)
        inputs = [{"task": system_prompt} for _ in range(batch_size)]
        batch = list(textcat_generator.process(inputs=inputs))
        textcat_results.extend(batch[0])
        n_processed += batch_size
    for result in textcat_results:
        result["text"] = result["input_text"]

    # label text classification data
    progress(0.5, desc="(1/2) Generating text classification data")
    if not is_sample:
        n_processed = 0
        labeller_results = []
        while n_processed < num_rows:
            progress(
                0.5 + 0.5 * n_processed / num_rows,
                total=total_steps,
                desc="(2/2) Labeling text classification data",
            )
            batch = textcat_results[n_processed : n_processed + batch_size]
            labels_batch = list(labeller_generator.process(inputs=batch))
            labeller_results.extend(labels_batch[0])
            n_processed += batch_size
    progress(
        1,
        total=total_steps,
        desc="(2/2) Creating dataset",
    )

    # create final dataset
    distiset_results = []
    source_results = textcat_results if is_sample else labeller_results
    for result in source_results:
        record = {
            key: result[key]
            for key in ["text", "label" if is_sample else "labels"]
            if key in result
        }
        distiset_results.append(record)

    dataframe = pd.DataFrame(distiset_results)
    if not is_sample:
        if num_labels == 1:
            dataframe = dataframe.rename(columns={"labels": "label"})
            # normalize labels and drop any that fall outside the taxonomy
            dataframe["label"] = dataframe["label"].apply(
                lambda x: x.lower().strip() if x.lower().strip() in labels else None
            )
        else:
            dataframe["labels"] = dataframe["labels"].apply(
                lambda x: (
                    [
                        label.lower().strip()
                        for label in x
                        if label.lower().strip() in labels
                    ]
                    if isinstance(x, list)
                    else None
                )
            )
    progress(1.0, desc="Dataset generation completed")
    return dataframe
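
# Sketch of calling `generate_dataset` directly (kept as a comment; it assumes
# the generators behind `get_textcat_generator` / `get_labeller_generator` are
# configured with a working inference endpoint, and the prompt and labels below
# are illustrative only):
#
#   df = generate_dataset(
#       system_prompt="Classify customer reviews by sentiment.",
#       difficulty="mixed",
#       clarity="mixed",
#       labels=["positive", "negative"],
#       num_labels=1,
#       num_rows=10,
#   )
#   # -> DataFrame with columns ["text", "label"] (or ["text", "labels"] when
#   #    num_labels > 1); generated labels outside the taxonomy become None.
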
def update_suggested_labels(system_prompt):
    new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
    if not new_labels:
        return gr.Warning(
            "No labels found in the system prompt. Please add labels manually."
        )
    return gr.update(choices=new_labels, value=new_labels)


def validate_input_labels(labels):
    if not labels or len(labels) < 2:
        raise gr.Error(
            f"Please select at least 2 labels to classify your text. You selected {len(labels) if labels else 0}."
        )
    return labels


def update_max_num_labels(labels):
    return gr.update(maximum=len(labels) if labels else 1)
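
# `update_suggested_labels` extracts single-quoted tokens from the system
# prompt, so a prompt like the one below (illustrative) yields three choices:
#
#   update_suggested_labels("Label each text as 'positive', 'negative' or 'spam-like'.")
#   # -> gr.update(choices=["positive", "negative", "spam-like"], value=[...])
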
(
    app,
    main_ui,
    custom_input_ui,
    dataset_description,
    examples,
    btn_generate_system_prompt,
    system_prompt,
    sample_dataset,
    btn_generate_sample_dataset,
    dataset_name,
    add_to_existing_dataset,
    btn_generate_full_dataset_argilla,
    btn_generate_and_push_to_argilla,
    btn_push_to_argilla,
    org_name,
    repo_name,
    private,
    btn_generate_full_dataset,
    btn_generate_and_push_to_hub,
    btn_push_to_hub,
    final_dataset,
    success_message,
) = get_main_ui(
    default_dataset_descriptions=DEFAULT_DATASET_DESCRIPTIONS,
    default_system_prompts=DEFAULT_SYSTEM_PROMPTS,
    default_datasets=DEFAULT_DATASETS,
    fn_generate_system_prompt=generate_system_prompt,
    fn_generate_dataset=generate_dataset,
    task=TASK,
)

with app:
    with main_ui:
        with custom_input_ui:
            difficulty = gr.Dropdown(
                choices=[
                    ("High School", "high school"),
                    ("College", "college"),
                    ("PhD", "PhD"),
                    ("Mixed", "mixed"),
                ],
                value="mixed",
                label="Difficulty",
                info="Select the comprehension level for the text. Ensure it matches the task context.",
            )
            clarity = gr.Dropdown(
                choices=[
                    ("Clear", "clear"),
                    (
                        "Understandable",
                        "understandable with some effort",
                    ),
                    ("Ambiguous", "ambiguous"),
                    ("Mixed", "mixed"),
                ],
                value="mixed",
                label="Clarity",
                info="Set how easily the correct label or labels can be identified.",
            )
            with gr.Column():
                labels = gr.Dropdown(
                    choices=[],
                    allow_custom_value=True,
                    interactive=True,
                    label="Labels",
                    multiselect=True,
                    info="Add the labels to classify the text.",
                )
                with gr.Blocks():
                    btn_suggested_labels = gr.Button(
                        value="Add suggested labels",
                        size="sm",
                    )
            num_labels = gr.Number(
                label="Number of labels per text",
                value=1,
                minimum=1,
                maximum=10,
                info="Select 1 for single-label and >1 for multi-label.",
            )
            num_rows = gr.Number(
                label="Number of rows",
                value=10,
                minimum=1,
                maximum=500,
                info="Select the number of rows in the dataset. More rows will take more time.",
            )

        pipeline_code = get_pipeline_code_ui(
            generate_pipeline_code(
                system_prompt.value,
                difficulty=difficulty.value,
                clarity=clarity.value,
                labels=labels.value,
                num_labels=num_labels.value,
                num_rows=num_rows.value,
            )
        )

    # define app triggers
    btn_suggested_labels.click(
        fn=update_suggested_labels,
        inputs=[system_prompt],
        outputs=labels,
    ).then(
        fn=update_max_num_labels,
        inputs=[labels],
        outputs=[num_labels],
    )

    gr.on(
        triggers=[
            btn_generate_full_dataset.click,
            btn_generate_full_dataset_argilla.click,
        ],
        fn=hide_success_message,
        outputs=[success_message],
    ).then(
        fn=validate_input_labels,
        inputs=[labels],
        outputs=[labels],
    ).success(
        fn=generate_dataset,
        inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
        outputs=[final_dataset],
        show_progress=True,
    )

    btn_generate_and_push_to_argilla.click(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[dataset_name, final_dataset, add_to_existing_dataset],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=generate_dataset,
        inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=push_dataset_to_argilla,
        inputs=[final_dataset, dataset_name, num_labels, labels],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=show_success_message_argilla,
        inputs=[],
        outputs=[success_message],
    )

    btn_generate_and_push_to_hub.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).then(
        fn=generate_dataset,
        inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
        outputs=[final_dataset],
        show_progress=True,
    ).then(
        fn=push_dataset_to_hub,
        inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
        outputs=[final_dataset],
        show_progress=True,
    ).then(
        fn=push_pipeline_code_to_hub,
        inputs=[pipeline_code, org_name, repo_name],
        outputs=[],
        show_progress=True,
    ).success(
        fn=show_success_message_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    )

    btn_push_to_hub.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).then(
        fn=push_dataset_to_hub,
        inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
        outputs=[final_dataset],
        show_progress=True,
    ).then(
        fn=push_pipeline_code_to_hub,
        inputs=[pipeline_code, org_name, repo_name],
        outputs=[],
        show_progress=True,
    ).success(
        fn=show_success_message_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    )

    btn_push_to_argilla.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[dataset_name, final_dataset, add_to_existing_dataset],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=push_dataset_to_argilla,
        inputs=[final_dataset, dataset_name, num_labels, labels],
        outputs=[final_dataset],
        show_progress=True,
    ).success(
        fn=show_success_message_argilla,
        inputs=[],
        outputs=[success_message],
    )

    system_prompt.change(
        fn=generate_pipeline_code,
        inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
        outputs=[pipeline_code],
    )
    difficulty.change(
        fn=generate_pipeline_code,
        inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
        outputs=[pipeline_code],
    )
    clarity.change(
        fn=generate_pipeline_code,
        inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
        outputs=[pipeline_code],
    )
    labels.change(
        fn=generate_pipeline_code,
        inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
        outputs=[pipeline_code],
    ).then(
        fn=update_max_num_labels,
        inputs=[labels],
        outputs=[num_labels],
    )

    num_labels.change(
        fn=generate_pipeline_code,
        inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
        outputs=[pipeline_code],
    )
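
# A minimal entry point, assuming this module is meant to be runnable on its
# own; in the upstream project the app may instead be launched from a central
# entry point, so treat this guard as an optional convenience.
if __name__ == "__main__":
    app.launch()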