Spaces:
Runtime error
Runtime error
Merge pull request #8 from argilla-io/feat/add-multi-label
Browse files
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -69,14 +69,14 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
|
69 |
|
70 |
|
71 |
def generate_sample_dataset(
|
72 |
-
system_prompt, difficulty, clarity, labels,
|
73 |
):
|
74 |
dataframe = generate_dataset(
|
75 |
system_prompt=system_prompt,
|
76 |
difficulty=difficulty,
|
77 |
clarity=clarity,
|
78 |
labels=labels,
|
79 |
-
|
80 |
num_rows=10,
|
81 |
progress=progress,
|
82 |
is_sample=True,
|
@@ -89,7 +89,7 @@ def generate_dataset(
|
|
89 |
difficulty: str,
|
90 |
clarity: str,
|
91 |
labels: List[str] = None,
|
92 |
-
|
93 |
num_rows: int = 10,
|
94 |
temperature: float = 0.9,
|
95 |
is_sample: bool = False,
|
@@ -105,9 +105,9 @@ def generate_dataset(
|
|
105 |
is_sample=is_sample,
|
106 |
)
|
107 |
labeller_generator = get_labeller_generator(
|
108 |
-
system_prompt=f"{system_prompt} {', '.join(labels)}",
|
109 |
labels=labels,
|
110 |
-
|
111 |
)
|
112 |
total_steps: int = num_rows * 2
|
113 |
batch_size = DEFAULT_BATCH_SIZE
|
@@ -125,11 +125,16 @@ def generate_dataset(
|
|
125 |
batch_size = min(batch_size, remaining_rows)
|
126 |
inputs = []
|
127 |
for _ in range(batch_size):
|
128 |
-
if
|
129 |
-
num_labels =
|
|
|
|
|
|
|
|
|
130 |
else:
|
131 |
-
|
132 |
-
|
|
|
133 |
random.shuffle(sampled_labels)
|
134 |
inputs.append(
|
135 |
{
|
@@ -169,12 +174,7 @@ def generate_dataset(
|
|
169 |
distiset_results.append(record)
|
170 |
|
171 |
dataframe = pd.DataFrame(distiset_results)
|
172 |
-
if
|
173 |
-
dataframe = dataframe.rename(columns={"labels": "label"})
|
174 |
-
dataframe["label"] = dataframe["label"].apply(
|
175 |
-
lambda x: x.lower().strip() if x.lower().strip() in labels else None
|
176 |
-
)
|
177 |
-
else:
|
178 |
dataframe["labels"] = dataframe["labels"].apply(
|
179 |
lambda x: list(
|
180 |
set(
|
@@ -186,6 +186,12 @@ def generate_dataset(
|
|
186 |
)
|
187 |
)
|
188 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
progress(1.0, desc="Dataset created")
|
190 |
return dataframe
|
191 |
|
@@ -194,7 +200,7 @@ def push_dataset_to_hub(
|
|
194 |
dataframe: pd.DataFrame,
|
195 |
org_name: str,
|
196 |
repo_name: str,
|
197 |
-
|
198 |
labels: List[str] = None,
|
199 |
oauth_token: Union[gr.OAuthToken, None] = None,
|
200 |
private: bool = False,
|
@@ -206,18 +212,17 @@ def push_dataset_to_hub(
|
|
206 |
progress(0.3, desc="Preprocessing")
|
207 |
labels = get_preprocess_labels(labels)
|
208 |
progress(0.7, desc="Creating dataset")
|
209 |
-
if
|
210 |
-
dataframe["label"] = dataframe["label"].replace("", None)
|
211 |
-
features = Features(
|
212 |
-
{"text": Value("string"), "label": ClassLabel(names=labels)}
|
213 |
-
)
|
214 |
-
else:
|
215 |
features = Features(
|
216 |
{
|
217 |
"text": Value("string"),
|
218 |
"labels": Sequence(feature=ClassLabel(names=labels)),
|
219 |
}
|
220 |
)
|
|
|
|
|
|
|
|
|
221 |
dataset = Dataset.from_pandas(dataframe, features=features)
|
222 |
dataset = combine_datasets(repo_id, dataset)
|
223 |
distiset = Distiset({"default": dataset})
|
@@ -239,7 +244,7 @@ def push_dataset(
|
|
239 |
system_prompt: str,
|
240 |
difficulty: str,
|
241 |
clarity: str,
|
242 |
-
|
243 |
num_rows: int = 10,
|
244 |
labels: List[str] = None,
|
245 |
private: bool = False,
|
@@ -252,7 +257,7 @@ def push_dataset(
|
|
252 |
system_prompt=system_prompt,
|
253 |
difficulty=difficulty,
|
254 |
clarity=clarity,
|
255 |
-
|
256 |
labels=labels,
|
257 |
num_rows=num_rows,
|
258 |
temperature=temperature,
|
@@ -261,7 +266,7 @@ def push_dataset(
|
|
261 |
dataframe,
|
262 |
org_name,
|
263 |
repo_name,
|
264 |
-
|
265 |
labels,
|
266 |
oauth_token,
|
267 |
private,
|
@@ -288,19 +293,19 @@ def push_dataset(
|
|
288 |
],
|
289 |
questions=[
|
290 |
(
|
291 |
-
rg.
|
292 |
-
name="label",
|
293 |
-
title="Label",
|
294 |
-
description="The label of the text",
|
295 |
-
labels=labels,
|
296 |
-
)
|
297 |
-
if num_labels == 1
|
298 |
-
else rg.MultiLabelQuestion(
|
299 |
name="labels",
|
300 |
title="Labels",
|
301 |
description="The labels of the conversation",
|
302 |
labels=labels,
|
303 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
),
|
305 |
],
|
306 |
metadata=[
|
@@ -340,16 +345,16 @@ def push_dataset(
|
|
340 |
suggestions=(
|
341 |
[
|
342 |
rg.Suggestion(
|
343 |
-
question_name="
|
344 |
value=(
|
345 |
-
sample["
|
346 |
),
|
347 |
)
|
348 |
]
|
349 |
if (
|
350 |
-
(
|
351 |
or (
|
352 |
-
|
353 |
and all(label in labels for label in sample["labels"])
|
354 |
)
|
355 |
)
|
@@ -373,10 +378,6 @@ def validate_input_labels(labels):
|
|
373 |
return labels
|
374 |
|
375 |
|
376 |
-
def update_max_num_labels(labels):
|
377 |
-
return gr.update(maximum=len(labels) if labels else 1)
|
378 |
-
|
379 |
-
|
380 |
def show_pipeline_code_visibility():
|
381 |
return {pipeline_code_ui: gr.Accordion(visible=True)}
|
382 |
|
@@ -434,13 +435,11 @@ with gr.Blocks() as app:
|
|
434 |
multiselect=True,
|
435 |
info="Add the labels to classify the text.",
|
436 |
)
|
437 |
-
|
438 |
-
label="
|
439 |
-
value=
|
440 |
-
minimum=1,
|
441 |
-
maximum=10,
|
442 |
-
info="Select 1 for single-label and >1 for multi-label.",
|
443 |
interactive=True,
|
|
|
444 |
)
|
445 |
clarity = gr.Dropdown(
|
446 |
choices=[
|
@@ -521,7 +520,7 @@ with gr.Blocks() as app:
|
|
521 |
difficulty=difficulty.value,
|
522 |
clarity=clarity.value,
|
523 |
labels=labels.value,
|
524 |
-
num_labels=
|
525 |
num_rows=num_rows.value,
|
526 |
temperature=temperature.value,
|
527 |
)
|
@@ -538,24 +537,14 @@ with gr.Blocks() as app:
|
|
538 |
show_progress=True,
|
539 |
).then(
|
540 |
fn=generate_sample_dataset,
|
541 |
-
inputs=[system_prompt, difficulty, clarity, labels,
|
542 |
outputs=[dataframe],
|
543 |
show_progress=True,
|
544 |
-
).then(
|
545 |
-
fn=update_max_num_labels,
|
546 |
-
inputs=[labels],
|
547 |
-
outputs=[num_labels],
|
548 |
-
)
|
549 |
-
|
550 |
-
labels.input(
|
551 |
-
fn=update_max_num_labels,
|
552 |
-
inputs=[labels],
|
553 |
-
outputs=[num_labels],
|
554 |
)
|
555 |
|
556 |
btn_apply_to_sample_dataset.click(
|
557 |
fn=generate_sample_dataset,
|
558 |
-
inputs=[system_prompt, difficulty, clarity, labels,
|
559 |
outputs=[dataframe],
|
560 |
show_progress=True,
|
561 |
)
|
@@ -586,7 +575,7 @@ with gr.Blocks() as app:
|
|
586 |
system_prompt,
|
587 |
difficulty,
|
588 |
clarity,
|
589 |
-
|
590 |
num_rows,
|
591 |
labels,
|
592 |
private,
|
@@ -606,7 +595,7 @@ with gr.Blocks() as app:
|
|
606 |
difficulty,
|
607 |
clarity,
|
608 |
labels,
|
609 |
-
|
610 |
num_rows,
|
611 |
temperature,
|
612 |
],
|
|
|
69 |
|
70 |
|
71 |
def generate_sample_dataset(
|
72 |
+
system_prompt, difficulty, clarity, labels, multi_label, progress=gr.Progress()
|
73 |
):
|
74 |
dataframe = generate_dataset(
|
75 |
system_prompt=system_prompt,
|
76 |
difficulty=difficulty,
|
77 |
clarity=clarity,
|
78 |
labels=labels,
|
79 |
+
multi_label=multi_label,
|
80 |
num_rows=10,
|
81 |
progress=progress,
|
82 |
is_sample=True,
|
|
|
89 |
difficulty: str,
|
90 |
clarity: str,
|
91 |
labels: List[str] = None,
|
92 |
+
multi_label: bool = False,
|
93 |
num_rows: int = 10,
|
94 |
temperature: float = 0.9,
|
95 |
is_sample: bool = False,
|
|
|
105 |
is_sample=is_sample,
|
106 |
)
|
107 |
labeller_generator = get_labeller_generator(
|
108 |
+
system_prompt=f"{system_prompt}. Potential labels: {', '.join(labels)}",
|
109 |
labels=labels,
|
110 |
+
multi_label=multi_label,
|
111 |
)
|
112 |
total_steps: int = num_rows * 2
|
113 |
batch_size = DEFAULT_BATCH_SIZE
|
|
|
125 |
batch_size = min(batch_size, remaining_rows)
|
126 |
inputs = []
|
127 |
for _ in range(batch_size):
|
128 |
+
if multi_label:
|
129 |
+
num_labels = len(labels)
|
130 |
+
k = int(
|
131 |
+
random.betavariate(alpha=(num_labels - 1), beta=num_labels)
|
132 |
+
* num_labels
|
133 |
+
)
|
134 |
else:
|
135 |
+
k = 1
|
136 |
+
|
137 |
+
sampled_labels = random.sample(labels, min(k, len(labels)))
|
138 |
random.shuffle(sampled_labels)
|
139 |
inputs.append(
|
140 |
{
|
|
|
174 |
distiset_results.append(record)
|
175 |
|
176 |
dataframe = pd.DataFrame(distiset_results)
|
177 |
+
if multi_label:
|
|
|
|
|
|
|
|
|
|
|
178 |
dataframe["labels"] = dataframe["labels"].apply(
|
179 |
lambda x: list(
|
180 |
set(
|
|
|
186 |
)
|
187 |
)
|
188 |
)
|
189 |
+
else:
|
190 |
+
dataframe = dataframe.rename(columns={"labels": "label"})
|
191 |
+
dataframe["label"] = dataframe["label"].apply(
|
192 |
+
lambda x: x.lower().strip() if x.lower().strip() in labels else None
|
193 |
+
)
|
194 |
+
|
195 |
progress(1.0, desc="Dataset created")
|
196 |
return dataframe
|
197 |
|
|
|
200 |
dataframe: pd.DataFrame,
|
201 |
org_name: str,
|
202 |
repo_name: str,
|
203 |
+
multi_label: bool = False,
|
204 |
labels: List[str] = None,
|
205 |
oauth_token: Union[gr.OAuthToken, None] = None,
|
206 |
private: bool = False,
|
|
|
212 |
progress(0.3, desc="Preprocessing")
|
213 |
labels = get_preprocess_labels(labels)
|
214 |
progress(0.7, desc="Creating dataset")
|
215 |
+
if multi_label:
|
|
|
|
|
|
|
|
|
|
|
216 |
features = Features(
|
217 |
{
|
218 |
"text": Value("string"),
|
219 |
"labels": Sequence(feature=ClassLabel(names=labels)),
|
220 |
}
|
221 |
)
|
222 |
+
else:
|
223 |
+
features = Features(
|
224 |
+
{"text": Value("string"), "label": ClassLabel(names=labels)}
|
225 |
+
)
|
226 |
dataset = Dataset.from_pandas(dataframe, features=features)
|
227 |
dataset = combine_datasets(repo_id, dataset)
|
228 |
distiset = Distiset({"default": dataset})
|
|
|
244 |
system_prompt: str,
|
245 |
difficulty: str,
|
246 |
clarity: str,
|
247 |
+
multi_label: int = 1,
|
248 |
num_rows: int = 10,
|
249 |
labels: List[str] = None,
|
250 |
private: bool = False,
|
|
|
257 |
system_prompt=system_prompt,
|
258 |
difficulty=difficulty,
|
259 |
clarity=clarity,
|
260 |
+
multi_label=multi_label,
|
261 |
labels=labels,
|
262 |
num_rows=num_rows,
|
263 |
temperature=temperature,
|
|
|
266 |
dataframe,
|
267 |
org_name,
|
268 |
repo_name,
|
269 |
+
multi_label,
|
270 |
labels,
|
271 |
oauth_token,
|
272 |
private,
|
|
|
293 |
],
|
294 |
questions=[
|
295 |
(
|
296 |
+
rg.MultiLabelQuestion(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
name="labels",
|
298 |
title="Labels",
|
299 |
description="The labels of the conversation",
|
300 |
labels=labels,
|
301 |
)
|
302 |
+
if multi_label
|
303 |
+
else rg.LabelQuestion(
|
304 |
+
name="label",
|
305 |
+
title="Label",
|
306 |
+
description="The label of the text",
|
307 |
+
labels=labels,
|
308 |
+
)
|
309 |
),
|
310 |
],
|
311 |
metadata=[
|
|
|
345 |
suggestions=(
|
346 |
[
|
347 |
rg.Suggestion(
|
348 |
+
question_name="labels" if multi_label else "label",
|
349 |
value=(
|
350 |
+
sample["labels"] if multi_label else sample["label"]
|
351 |
),
|
352 |
)
|
353 |
]
|
354 |
if (
|
355 |
+
(not multi_label and sample["label"] in labels)
|
356 |
or (
|
357 |
+
multi_label
|
358 |
and all(label in labels for label in sample["labels"])
|
359 |
)
|
360 |
)
|
|
|
378 |
return labels
|
379 |
|
380 |
|
|
|
|
|
|
|
|
|
381 |
def show_pipeline_code_visibility():
|
382 |
return {pipeline_code_ui: gr.Accordion(visible=True)}
|
383 |
|
|
|
435 |
multiselect=True,
|
436 |
info="Add the labels to classify the text.",
|
437 |
)
|
438 |
+
multi_label = gr.Checkbox(
|
439 |
+
label="Multi-label",
|
440 |
+
value=False,
|
|
|
|
|
|
|
441 |
interactive=True,
|
442 |
+
info="If checked, the text will be classified into multiple labels.",
|
443 |
)
|
444 |
clarity = gr.Dropdown(
|
445 |
choices=[
|
|
|
520 |
difficulty=difficulty.value,
|
521 |
clarity=clarity.value,
|
522 |
labels=labels.value,
|
523 |
+
num_labels=len(labels.value) if multi_label.value else 1,
|
524 |
num_rows=num_rows.value,
|
525 |
temperature=temperature.value,
|
526 |
)
|
|
|
537 |
show_progress=True,
|
538 |
).then(
|
539 |
fn=generate_sample_dataset,
|
540 |
+
inputs=[system_prompt, difficulty, clarity, labels, multi_label],
|
541 |
outputs=[dataframe],
|
542 |
show_progress=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
543 |
)
|
544 |
|
545 |
btn_apply_to_sample_dataset.click(
|
546 |
fn=generate_sample_dataset,
|
547 |
+
inputs=[system_prompt, difficulty, clarity, labels, multi_label],
|
548 |
outputs=[dataframe],
|
549 |
show_progress=True,
|
550 |
)
|
|
|
575 |
system_prompt,
|
576 |
difficulty,
|
577 |
clarity,
|
578 |
+
multi_label,
|
579 |
num_rows,
|
580 |
labels,
|
581 |
private,
|
|
|
595 |
difficulty,
|
596 |
clarity,
|
597 |
labels,
|
598 |
+
multi_label,
|
599 |
num_rows,
|
600 |
temperature,
|
601 |
],
|
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
@@ -29,7 +29,7 @@ Description: DavidMovieHouse is a cinema that has been in business for 10 years.
|
|
29 |
Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customers reviews of varying customer groups. Classify the customer reviews as", "labels": ["positive", "negative"]}
|
30 |
|
31 |
Description: A dataset that focuses on creating neo-ludite discussions about technologies within the AI space.
|
32 |
-
Output: {"classification_task": "Neo-ludiite discussions about technologies within the AI space cover from different speaking people
|
33 |
|
34 |
Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
|
35 |
Output: {"classification_task": "TechSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Written by different journalists. Determine the category of based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
|
@@ -102,7 +102,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
|
|
102 |
return textcat_generator
|
103 |
|
104 |
|
105 |
-
def get_labeller_generator(system_prompt, labels,
|
106 |
labeller_generator = TextClassification(
|
107 |
llm=InferenceEndpointsLLM(
|
108 |
model_id=MODEL,
|
@@ -115,7 +115,7 @@ def get_labeller_generator(system_prompt, labels, num_labels):
|
|
115 |
),
|
116 |
context=system_prompt,
|
117 |
available_labels=labels,
|
118 |
-
n=
|
119 |
default_label="unknown",
|
120 |
)
|
121 |
labeller_generator.load()
|
|
|
29 |
Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customers reviews of varying customer groups. Classify the customer reviews as", "labels": ["positive", "negative"]}
|
30 |
|
31 |
Description: A dataset that focuses on creating neo-ludite discussions about technologies within the AI space.
|
32 |
+
Output: {"classification_task": "Neo-ludiite discussions about technologies within the AI space cover from different speaking people. Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
|
33 |
|
34 |
Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
|
35 |
Output: {"classification_task": "TechSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Written by different journalists. Determine the category of based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
|
|
|
102 |
return textcat_generator
|
103 |
|
104 |
|
105 |
+
def get_labeller_generator(system_prompt, labels, multi_label):
|
106 |
labeller_generator = TextClassification(
|
107 |
llm=InferenceEndpointsLLM(
|
108 |
model_id=MODEL,
|
|
|
115 |
),
|
116 |
context=system_prompt,
|
117 |
available_labels=labels,
|
118 |
+
n=len(labels) if multi_label else 1,
|
119 |
default_label="unknown",
|
120 |
)
|
121 |
labeller_generator.load()
|