ZeroCommand commited on
Commit
d81d6fd
1 Parent(s): d2ff920

add conditions to extract labels from dataset

Browse files
Files changed (1) hide show
  1. text_classification.py +9 -1
text_classification.py CHANGED
@@ -15,7 +15,15 @@ def get_labels_and_features_from_dataset(dataset_id, dataset_config, split):
15
  try:
16
  ds = datasets.load_dataset(dataset_id, dataset_config)[split]
17
  dataset_features = ds.features
18
- labels = dataset_features["label"].names
 
 
 
 
 
 
 
 
19
  features = [f for f in dataset_features.keys() if f != "label"]
20
  return labels, features
21
  except Exception as e:
 
15
  try:
16
  ds = datasets.load_dataset(dataset_id, dataset_config)[split]
17
  dataset_features = ds.features
18
+ label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
19
+ if len(label_keys) == 0:
20
+ raise ValueError("Dataset does not have label column")
21
+ if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
22
+ if hasattr(dataset_features[label_keys[0]], 'feature'):
23
+ label_feat = dataset_features[label_keys[0]].feature
24
+ labels = label_feat.names
25
+ else:
26
+ labels = [dataset_features[label_keys[0]].names]
27
  features = [f for f in dataset_features.keys() if f != "label"]
28
  return labels, features
29
  except Exception as e: