lewtun HF staff commited on
Commit
fcdf4a0
1 Parent(s): 4cef17f

Add support for image classification

Browse files
Files changed (2) hide show
  1. app.py +40 -4
  2. utils.py +3 -2
app.py CHANGED
@@ -31,11 +31,12 @@ DATASETS_PREVIEW_API = os.getenv("DATASETS_PREVIEW_API")
31
  TASK_TO_ID = {
32
  "binary_classification": 1,
33
  "multi_class_classification": 2,
34
- # "multi_label_classification": 3, # Not fully supported in AutoTrain
35
  "entity_extraction": 4,
36
  "extractive_question_answering": 5,
37
  "translation": 6,
38
  "summarization": 8,
 
 
39
  }
40
 
41
  TASK_TO_DEFAULT_METRICS = {
@@ -50,8 +51,22 @@ TASK_TO_DEFAULT_METRICS = {
50
  "extractive_question_answering": [],
51
  "translation": ["sacrebleu"],
52
  "summarization": ["rouge1", "rouge2", "rougeL", "rougeLsum"],
 
 
 
 
 
 
 
53
  }
54
 
 
 
 
 
 
 
 
55
  SUPPORTED_TASKS = list(TASK_TO_ID.keys())
56
 
57
  # Extracted from utils.get_supported_metrics
@@ -355,6 +370,27 @@ with st.expander("Advanced configuration"):
355
  col_mapping[question_col] = "question"
356
  col_mapping[answers_text_col] = "answers.text"
357
  col_mapping[answers_start_col] = "answers.answer_start"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
 
359
  # Select metrics
360
  st.markdown("**Select metrics**")
@@ -408,9 +444,9 @@ with st.form(key="form"):
408
  "proj_name": f"eval-project-{project_id}",
409
  "task": TASK_TO_ID[selected_task],
410
  "config": {
411
- "language": "en"
412
- if selected_task != "translation"
413
- else "en2de", # Need this dummy pair to enable translation
414
  "max_models": 5,
415
  "instance": {
416
  "provider": "aws",
 
31
  TASK_TO_ID = {
32
  "binary_classification": 1,
33
  "multi_class_classification": 2,
 
34
  "entity_extraction": 4,
35
  "extractive_question_answering": 5,
36
  "translation": 6,
37
  "summarization": 8,
38
+ "image_binary_classification": 17,
39
+ "image_multi_class_classification": 18,
40
  }
41
 
42
  TASK_TO_DEFAULT_METRICS = {
 
51
  "extractive_question_answering": [],
52
  "translation": ["sacrebleu"],
53
  "summarization": ["rouge1", "rouge2", "rougeL", "rougeLsum"],
54
+ "image_binary_classification": ["f1", "precision", "recall", "auc", "accuracy"],
55
+ "image_multi_class_classification": [
56
+ "f1",
57
+ "precision",
58
+ "recall",
59
+ "accuracy",
60
+ ],
61
  }
62
 
63
+ AUTOTRAIN_TASK_TO_LANG = {
64
+ "translation": "en2de",
65
+ "image_binary_classification": "unk",
66
+ "image_multi_class_classification": "unk",
67
+ }
68
+
69
+
70
  SUPPORTED_TASKS = list(TASK_TO_ID.keys())
71
 
72
  # Extracted from utils.get_supported_metrics
 
370
  col_mapping[question_col] = "question"
371
  col_mapping[answers_text_col] = "answers.text"
372
  col_mapping[answers_start_col] = "answers.answer_start"
373
+ elif selected_task in ["image_binary_classification", "image_multi_class_classification"]:
374
+ with col1:
375
+ st.markdown("`image` column")
376
+ st.text("")
377
+ st.text("")
378
+ st.text("")
379
+ st.text("")
380
+ st.markdown("`target` column")
381
+ with col2:
382
+ image_col = st.selectbox(
383
+ "This column should contain the images to be classified",
384
+ col_names,
385
+ index=col_names.index(get_key(metadata[0]["col_mapping"], "image")) if metadata is not None else 0,
386
+ )
387
+ target_col = st.selectbox(
388
+ "This column should contain the labels associated with the images",
389
+ col_names,
390
+ index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
391
+ )
392
+ col_mapping[image_col] = "image"
393
+ col_mapping[target_col] = "target"
394
 
395
  # Select metrics
396
  st.markdown("**Select metrics**")
 
444
  "proj_name": f"eval-project-{project_id}",
445
  "task": TASK_TO_ID[selected_task],
446
  "config": {
447
+ "language": AUTOTRAIN_TASK_TO_LANG[selected_task]
448
+ if selected_task in AUTOTRAIN_TASK_TO_LANG
449
+ else "en",
450
  "max_models": 5,
451
  "instance": {
452
  "provider": "aws",
utils.py CHANGED
@@ -11,14 +11,15 @@ from tqdm import tqdm
11
  AUTOTRAIN_TASK_TO_HUB_TASK = {
12
  "binary_classification": "text-classification",
13
  "multi_class_classification": "text-classification",
14
- # "multi_label_classification": "text-classification", # Not fully supported in AutoTrain
15
  "entity_extraction": "token-classification",
16
  "extractive_question_answering": "question-answering",
17
  "translation": "translation",
18
  "summarization": "summarization",
19
- # "single_column_regression": 10,
 
20
  }
21
 
 
22
  HUB_TASK_TO_AUTOTRAIN_TASK = {v: k for k, v in AUTOTRAIN_TASK_TO_HUB_TASK.items()}
23
  LOGS_REPO = "evaluation-job-logs"
24
 
 
11
  AUTOTRAIN_TASK_TO_HUB_TASK = {
12
  "binary_classification": "text-classification",
13
  "multi_class_classification": "text-classification",
 
14
  "entity_extraction": "token-classification",
15
  "extractive_question_answering": "question-answering",
16
  "translation": "translation",
17
  "summarization": "summarization",
18
+ "image_binary_classification": "image-classification",
19
+ "image_multi_class_classification": "image-classification",
20
  }
21
 
22
+
23
  HUB_TASK_TO_AUTOTRAIN_TASK = {v: k for k, v in AUTOTRAIN_TASK_TO_HUB_TASK.items()}
24
  LOGS_REPO = "evaluation-job-logs"
25