davidberenstein1957 HF staff committed on
Commit
10b52aa
·
unverified ·
2 Parent(s): 7b7c1be 0c1d5b6

Merge pull request #8 from argilla-io/feat/add-multi-label

Browse files
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -69,14 +69,14 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
69
 
70
 
71
  def generate_sample_dataset(
72
- system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress()
73
  ):
74
  dataframe = generate_dataset(
75
  system_prompt=system_prompt,
76
  difficulty=difficulty,
77
  clarity=clarity,
78
  labels=labels,
79
- num_labels=num_labels,
80
  num_rows=10,
81
  progress=progress,
82
  is_sample=True,
@@ -89,7 +89,7 @@ def generate_dataset(
89
  difficulty: str,
90
  clarity: str,
91
  labels: List[str] = None,
92
- num_labels: int = 1,
93
  num_rows: int = 10,
94
  temperature: float = 0.9,
95
  is_sample: bool = False,
@@ -105,9 +105,9 @@ def generate_dataset(
105
  is_sample=is_sample,
106
  )
107
  labeller_generator = get_labeller_generator(
108
- system_prompt=f"{system_prompt} {', '.join(labels)}",
109
  labels=labels,
110
- num_labels=num_labels,
111
  )
112
  total_steps: int = num_rows * 2
113
  batch_size = DEFAULT_BATCH_SIZE
@@ -125,11 +125,16 @@ def generate_dataset(
125
  batch_size = min(batch_size, remaining_rows)
126
  inputs = []
127
  for _ in range(batch_size):
128
- if num_labels == 1:
129
- num_labels = 1
 
 
 
 
130
  else:
131
- num_labels = int(random.gammavariate(2, 2) * num_labels)
132
- sampled_labels = random.sample(labels, num_labels)
 
133
  random.shuffle(sampled_labels)
134
  inputs.append(
135
  {
@@ -169,12 +174,7 @@ def generate_dataset(
169
  distiset_results.append(record)
170
 
171
  dataframe = pd.DataFrame(distiset_results)
172
- if num_labels == 1:
173
- dataframe = dataframe.rename(columns={"labels": "label"})
174
- dataframe["label"] = dataframe["label"].apply(
175
- lambda x: x.lower().strip() if x.lower().strip() in labels else None
176
- )
177
- else:
178
  dataframe["labels"] = dataframe["labels"].apply(
179
  lambda x: list(
180
  set(
@@ -186,6 +186,12 @@ def generate_dataset(
186
  )
187
  )
188
  )
 
 
 
 
 
 
189
  progress(1.0, desc="Dataset created")
190
  return dataframe
191
 
@@ -194,7 +200,7 @@ def push_dataset_to_hub(
194
  dataframe: pd.DataFrame,
195
  org_name: str,
196
  repo_name: str,
197
- num_labels: int = 1,
198
  labels: List[str] = None,
199
  oauth_token: Union[gr.OAuthToken, None] = None,
200
  private: bool = False,
@@ -206,18 +212,17 @@ def push_dataset_to_hub(
206
  progress(0.3, desc="Preprocessing")
207
  labels = get_preprocess_labels(labels)
208
  progress(0.7, desc="Creating dataset")
209
- if num_labels == 1:
210
- dataframe["label"] = dataframe["label"].replace("", None)
211
- features = Features(
212
- {"text": Value("string"), "label": ClassLabel(names=labels)}
213
- )
214
- else:
215
  features = Features(
216
  {
217
  "text": Value("string"),
218
  "labels": Sequence(feature=ClassLabel(names=labels)),
219
  }
220
  )
 
 
 
 
221
  dataset = Dataset.from_pandas(dataframe, features=features)
222
  dataset = combine_datasets(repo_id, dataset)
223
  distiset = Distiset({"default": dataset})
@@ -239,7 +244,7 @@ def push_dataset(
239
  system_prompt: str,
240
  difficulty: str,
241
  clarity: str,
242
- num_labels: int = 1,
243
  num_rows: int = 10,
244
  labels: List[str] = None,
245
  private: bool = False,
@@ -252,7 +257,7 @@ def push_dataset(
252
  system_prompt=system_prompt,
253
  difficulty=difficulty,
254
  clarity=clarity,
255
- num_labels=num_labels,
256
  labels=labels,
257
  num_rows=num_rows,
258
  temperature=temperature,
@@ -261,7 +266,7 @@ def push_dataset(
261
  dataframe,
262
  org_name,
263
  repo_name,
264
- num_labels,
265
  labels,
266
  oauth_token,
267
  private,
@@ -288,19 +293,19 @@ def push_dataset(
288
  ],
289
  questions=[
290
  (
291
- rg.LabelQuestion(
292
- name="label",
293
- title="Label",
294
- description="The label of the text",
295
- labels=labels,
296
- )
297
- if num_labels == 1
298
- else rg.MultiLabelQuestion(
299
  name="labels",
300
  title="Labels",
301
  description="The labels of the conversation",
302
  labels=labels,
303
  )
 
 
 
 
 
 
 
304
  ),
305
  ],
306
  metadata=[
@@ -340,16 +345,16 @@ def push_dataset(
340
  suggestions=(
341
  [
342
  rg.Suggestion(
343
- question_name="label" if num_labels == 1 else "labels",
344
  value=(
345
- sample["label"] if num_labels == 1 else sample["labels"]
346
  ),
347
  )
348
  ]
349
  if (
350
- (num_labels == 1 and sample["label"] in labels)
351
  or (
352
- num_labels > 1
353
  and all(label in labels for label in sample["labels"])
354
  )
355
  )
@@ -373,10 +378,6 @@ def validate_input_labels(labels):
373
  return labels
374
 
375
 
376
- def update_max_num_labels(labels):
377
- return gr.update(maximum=len(labels) if labels else 1)
378
-
379
-
380
  def show_pipeline_code_visibility():
381
  return {pipeline_code_ui: gr.Accordion(visible=True)}
382
 
@@ -434,13 +435,11 @@ with gr.Blocks() as app:
434
  multiselect=True,
435
  info="Add the labels to classify the text.",
436
  )
437
- num_labels = gr.Number(
438
- label="Number of labels per text",
439
- value=1,
440
- minimum=1,
441
- maximum=10,
442
- info="Select 1 for single-label and >1 for multi-label.",
443
  interactive=True,
 
444
  )
445
  clarity = gr.Dropdown(
446
  choices=[
@@ -521,7 +520,7 @@ with gr.Blocks() as app:
521
  difficulty=difficulty.value,
522
  clarity=clarity.value,
523
  labels=labels.value,
524
- num_labels=num_labels.value,
525
  num_rows=num_rows.value,
526
  temperature=temperature.value,
527
  )
@@ -538,24 +537,14 @@ with gr.Blocks() as app:
538
  show_progress=True,
539
  ).then(
540
  fn=generate_sample_dataset,
541
- inputs=[system_prompt, difficulty, clarity, labels, num_labels],
542
  outputs=[dataframe],
543
  show_progress=True,
544
- ).then(
545
- fn=update_max_num_labels,
546
- inputs=[labels],
547
- outputs=[num_labels],
548
- )
549
-
550
- labels.input(
551
- fn=update_max_num_labels,
552
- inputs=[labels],
553
- outputs=[num_labels],
554
  )
555
 
556
  btn_apply_to_sample_dataset.click(
557
  fn=generate_sample_dataset,
558
- inputs=[system_prompt, difficulty, clarity, labels, num_labels],
559
  outputs=[dataframe],
560
  show_progress=True,
561
  )
@@ -586,7 +575,7 @@ with gr.Blocks() as app:
586
  system_prompt,
587
  difficulty,
588
  clarity,
589
- num_labels,
590
  num_rows,
591
  labels,
592
  private,
@@ -606,7 +595,7 @@ with gr.Blocks() as app:
606
  difficulty,
607
  clarity,
608
  labels,
609
- num_labels,
610
  num_rows,
611
  temperature,
612
  ],
 
69
 
70
 
71
  def generate_sample_dataset(
72
+ system_prompt, difficulty, clarity, labels, multi_label, progress=gr.Progress()
73
  ):
74
  dataframe = generate_dataset(
75
  system_prompt=system_prompt,
76
  difficulty=difficulty,
77
  clarity=clarity,
78
  labels=labels,
79
+ multi_label=multi_label,
80
  num_rows=10,
81
  progress=progress,
82
  is_sample=True,
 
89
  difficulty: str,
90
  clarity: str,
91
  labels: List[str] = None,
92
+ multi_label: bool = False,
93
  num_rows: int = 10,
94
  temperature: float = 0.9,
95
  is_sample: bool = False,
 
105
  is_sample=is_sample,
106
  )
107
  labeller_generator = get_labeller_generator(
108
+ system_prompt=f"{system_prompt}. Potential labels: {', '.join(labels)}",
109
  labels=labels,
110
+ multi_label=multi_label,
111
  )
112
  total_steps: int = num_rows * 2
113
  batch_size = DEFAULT_BATCH_SIZE
 
125
  batch_size = min(batch_size, remaining_rows)
126
  inputs = []
127
  for _ in range(batch_size):
128
+ if multi_label:
129
+ num_labels = len(labels)
130
+ k = int(
131
+ random.betavariate(alpha=(num_labels - 1), beta=num_labels)
132
+ * num_labels
133
+ )
134
  else:
135
+ k = 1
136
+
137
+ sampled_labels = random.sample(labels, min(k, len(labels)))
138
  random.shuffle(sampled_labels)
139
  inputs.append(
140
  {
 
174
  distiset_results.append(record)
175
 
176
  dataframe = pd.DataFrame(distiset_results)
177
+ if multi_label:
 
 
 
 
 
178
  dataframe["labels"] = dataframe["labels"].apply(
179
  lambda x: list(
180
  set(
 
186
  )
187
  )
188
  )
189
+ else:
190
+ dataframe = dataframe.rename(columns={"labels": "label"})
191
+ dataframe["label"] = dataframe["label"].apply(
192
+ lambda x: x.lower().strip() if x.lower().strip() in labels else None
193
+ )
194
+
195
  progress(1.0, desc="Dataset created")
196
  return dataframe
197
 
 
200
  dataframe: pd.DataFrame,
201
  org_name: str,
202
  repo_name: str,
203
+ multi_label: bool = False,
204
  labels: List[str] = None,
205
  oauth_token: Union[gr.OAuthToken, None] = None,
206
  private: bool = False,
 
212
  progress(0.3, desc="Preprocessing")
213
  labels = get_preprocess_labels(labels)
214
  progress(0.7, desc="Creating dataset")
215
+ if multi_label:
 
 
 
 
 
216
  features = Features(
217
  {
218
  "text": Value("string"),
219
  "labels": Sequence(feature=ClassLabel(names=labels)),
220
  }
221
  )
222
+ else:
223
+ features = Features(
224
+ {"text": Value("string"), "label": ClassLabel(names=labels)}
225
+ )
226
  dataset = Dataset.from_pandas(dataframe, features=features)
227
  dataset = combine_datasets(repo_id, dataset)
228
  distiset = Distiset({"default": dataset})
 
244
  system_prompt: str,
245
  difficulty: str,
246
  clarity: str,
247
+ multi_label: int = 1,
248
  num_rows: int = 10,
249
  labels: List[str] = None,
250
  private: bool = False,
 
257
  system_prompt=system_prompt,
258
  difficulty=difficulty,
259
  clarity=clarity,
260
+ multi_label=multi_label,
261
  labels=labels,
262
  num_rows=num_rows,
263
  temperature=temperature,
 
266
  dataframe,
267
  org_name,
268
  repo_name,
269
+ multi_label,
270
  labels,
271
  oauth_token,
272
  private,
 
293
  ],
294
  questions=[
295
  (
296
+ rg.MultiLabelQuestion(
 
 
 
 
 
 
 
297
  name="labels",
298
  title="Labels",
299
  description="The labels of the conversation",
300
  labels=labels,
301
  )
302
+ if multi_label
303
+ else rg.LabelQuestion(
304
+ name="label",
305
+ title="Label",
306
+ description="The label of the text",
307
+ labels=labels,
308
+ )
309
  ),
310
  ],
311
  metadata=[
 
345
  suggestions=(
346
  [
347
  rg.Suggestion(
348
+ question_name="labels" if multi_label else "label",
349
  value=(
350
+ sample["labels"] if multi_label else sample["label"]
351
  ),
352
  )
353
  ]
354
  if (
355
+ (not multi_label and sample["label"] in labels)
356
  or (
357
+ multi_label
358
  and all(label in labels for label in sample["labels"])
359
  )
360
  )
 
378
  return labels
379
 
380
 
 
 
 
 
381
  def show_pipeline_code_visibility():
382
  return {pipeline_code_ui: gr.Accordion(visible=True)}
383
 
 
435
  multiselect=True,
436
  info="Add the labels to classify the text.",
437
  )
438
+ multi_label = gr.Checkbox(
439
+ label="Multi-label",
440
+ value=False,
 
 
 
441
  interactive=True,
442
+ info="If checked, the text will be classified into multiple labels.",
443
  )
444
  clarity = gr.Dropdown(
445
  choices=[
 
520
  difficulty=difficulty.value,
521
  clarity=clarity.value,
522
  labels=labels.value,
523
+ num_labels=len(labels.value) if multi_label.value else 1,
524
  num_rows=num_rows.value,
525
  temperature=temperature.value,
526
  )
 
537
  show_progress=True,
538
  ).then(
539
  fn=generate_sample_dataset,
540
+ inputs=[system_prompt, difficulty, clarity, labels, multi_label],
541
  outputs=[dataframe],
542
  show_progress=True,
 
 
 
 
 
 
 
 
 
 
543
  )
544
 
545
  btn_apply_to_sample_dataset.click(
546
  fn=generate_sample_dataset,
547
+ inputs=[system_prompt, difficulty, clarity, labels, multi_label],
548
  outputs=[dataframe],
549
  show_progress=True,
550
  )
 
575
  system_prompt,
576
  difficulty,
577
  clarity,
578
+ multi_label,
579
  num_rows,
580
  labels,
581
  private,
 
595
  difficulty,
596
  clarity,
597
  labels,
598
+ multi_label,
599
  num_rows,
600
  temperature,
601
  ],
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -29,7 +29,7 @@ Description: DavidMovieHouse is a cinema that has been in business for 10 years.
29
  Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customers reviews of varying customer groups. Classify the customer reviews as", "labels": ["positive", "negative"]}
30
 
31
  Description: A dataset that focuses on creating neo-ludite discussions about technologies within the AI space.
32
- Output: {"classification_task": "Neo-ludiite discussions about technologies within the AI space cover from different speaking people . Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
33
 
34
  Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
35
  Output: {"classification_task": "TechSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Written by different journalists. Determine the category of based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
@@ -102,7 +102,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
102
  return textcat_generator
103
 
104
 
105
- def get_labeller_generator(system_prompt, labels, num_labels):
106
  labeller_generator = TextClassification(
107
  llm=InferenceEndpointsLLM(
108
  model_id=MODEL,
@@ -115,7 +115,7 @@ def get_labeller_generator(system_prompt, labels, num_labels):
115
  ),
116
  context=system_prompt,
117
  available_labels=labels,
118
- n=num_labels,
119
  default_label="unknown",
120
  )
121
  labeller_generator.load()
 
29
  Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customers reviews of varying customer groups. Classify the customer reviews as", "labels": ["positive", "negative"]}
30
 
31
  Description: A dataset that focuses on creating neo-ludite discussions about technologies within the AI space.
32
+ Output: {"classification_task": "Neo-ludiite discussions about technologies within the AI space cover from different speaking people. Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
33
 
34
  Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
35
  Output: {"classification_task": "TechSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Written by different journalists. Determine the category of based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
 
102
  return textcat_generator
103
 
104
 
105
+ def get_labeller_generator(system_prompt, labels, multi_label):
106
  labeller_generator = TextClassification(
107
  llm=InferenceEndpointsLLM(
108
  model_id=MODEL,
 
115
  ),
116
  context=system_prompt,
117
  available_labels=labels,
118
+ n=len(labels) if multi_label else 1,
119
  default_label="unknown",
120
  )
121
  labeller_generator.load()