JeremiahZ committed on
Commit
30cedfa
·
verified ·
1 Parent(s): bd6bc92

Upload run_classification.py

Browse files
Files changed (1) hide show
  1. run_classification.py +763 -0
run_classification.py ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2020 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Finetuning the library models for text classification."""
17
+ # You can also adapt this script on your own text classification task. Pointers for this are left as comments.
18
+
19
+ import logging
20
+ import os
21
+ import random
22
+ import sys
23
+ import warnings
24
+ from dataclasses import dataclass, field
25
+ from typing import List, Optional
26
+
27
+ import datasets
28
+ import evaluate
29
+ import numpy as np
30
+ from datasets import Value, load_dataset
31
+
32
+ import transformers
33
+ from transformers import (
34
+ AutoConfig,
35
+ AutoModelForSequenceClassification,
36
+ AutoTokenizer,
37
+ DataCollatorWithPadding,
38
+ EvalPrediction,
39
+ HfArgumentParser,
40
+ Trainer,
41
+ TrainingArguments,
42
+ default_data_collator,
43
+ set_seed,
44
+ )
45
+ from transformers.trainer_utils import get_last_checkpoint
46
+ from transformers.utils import check_min_version, send_example_telemetry
47
+ from transformers.utils.versions import require_version
48
+
49
+
50
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
# NOTE: the check below is commented out in this copy of the script, so no minimum Transformers
# version is enforced here.
# check_min_version("4.38.0.dev0")

# Require `datasets>=1.8.0`; the second argument is the hint passed along for the user.
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")


# Module-level logger; its level is set inside `main()` from the training arguments.
logger = logging.getLogger(__name__)
57
+
58
+
59
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.

    Either `dataset_name` or both `train_file` and `validation_file` must be provided;
    this is enforced in `__post_init__`.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    # NOTE(review): the annotation is intentionally left as plain `bool` even though the default is None.
    # HfArgumentParser derives the command-line behaviour of a flag from its annotation, so switching to
    # `Optional[bool]` may change how `--do_regression` is parsed -- confirm before changing.
    do_regression: bool = field(
        default=None,
        metadata={
            "help": "Whether to do regression instead of classification. If None, will be inferred from the dataset."
        },
    )
    # Comma-separated list of column names (split on "," by the consumer of this argument).
    text_column_names: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The name of the text column in the input dataset or a CSV/JSON file. "
                'If not specified, will use the "sentence" column for single/multi-label classification task.'
            )
        },
    )
    text_column_delimiter: Optional[str] = field(
        default=" ", metadata={"help": "THe delimiter to use to join text columns into a single sentence."}
    )
    train_split_name: Optional[str] = field(
        default=None,
        metadata={
            "help": 'The name of the train split in the input dataset. If not specified, will use the "train" split when do_train is enabled'
        },
    )
    validation_split_name: Optional[str] = field(
        default=None,
        metadata={
            "help": 'The name of the validation split in the input dataset. If not specified, will use the "validation" split when do_eval is enabled'
        },
    )
    test_split_name: Optional[str] = field(
        default=None,
        metadata={
            "help": 'The name of the test split in the input dataset. If not specified, will use the "test" split when do_predict is enabled'
        },
    )
    remove_splits: Optional[str] = field(
        default=None,
        metadata={"help": "The splits to remove from the dataset. Multiple splits should be separated by commas."},
    )
    remove_columns: Optional[str] = field(
        default=None,
        metadata={"help": "The columns to remove from the dataset. Multiple columns should be separated by commas."},
    )
    label_column_name: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The name of the label column in the input dataset or a CSV/JSON file. "
                'If not specified, will use the "label" column for single/multi-label classification task'
            )
        },
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    shuffle_train_dataset: bool = field(
        default=False, metadata={"help": "Whether to shuffle the train dataset or not."}
    )
    shuffle_seed: int = field(
        default=42, metadata={"help": "Random seed that will be used to shuffle the train dataset."}
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    metric_name: Optional[str] = field(default=None, metadata={"help": "The metric to use for evaluation."})
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})

    def __post_init__(self):
        """Validate the data source when no hub dataset name is given.

        Raises:
            ValueError: if neither `dataset_name` nor both `train_file` and `validation_file` are set,
                if `train_file` is not a csv/json file, or if `validation_file` has a different
                extension than `train_file`.
        """
        if self.dataset_name is not None:
            return

        if self.train_file is None or self.validation_file is None:
            raise ValueError("Need either a training/validation file or a dataset name.")

        # Explicit raises instead of `assert`: assertions are stripped when Python runs with -O,
        # which would silently skip this validation.
        train_extension = self.train_file.split(".")[-1]
        if train_extension not in ("csv", "json"):
            raise ValueError("`train_file` should be a csv or a json file.")
        validation_extension = self.validation_file.split(".")[-1]
        if validation_extension != train_extension:
            raise ValueError("`validation_file` should have the same extension (csv or json) as `train_file`.")
202
+
203
+
204
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Only `model_name_or_path` is required; every other field has a default.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # The default is None, so the annotation is Optional[str] rather than str.
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # NOTE(review): left as plain `bool` with a None default on purpose. HfArgumentParser derives the
    # command-line behaviour of a flag from its annotation, so `Optional[bool]` may change how
    # `--use_auth_token` is parsed -- confirm before changing.
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
    )
260
+
261
+
262
def get_label_list(raw_dataset, split="train") -> List[str]:
    """Return the distinct labels of one split as strings.

    Works for both single-label splits (each entry of the "label" column is a scalar) and
    multi-label splits (each entry is a list of labels).

    Args:
        raw_dataset: mapping from split name to a dataset that supports `dataset["label"]` and,
            for the single-label case, `dataset.unique("label")`.
        split: name of the split to inspect.

    Returns:
        The distinct labels converted with `str`, in no guaranteed order. An empty list is returned
        for an empty split.
    """
    # Read the column once; the original indexed it twice.
    labels = raw_dataset[split]["label"]
    if len(labels) == 0:
        # Nothing to inspect; the original raised IndexError on `labels[0]` here.
        return []

    if isinstance(labels[0], list):
        # Multi-label: flatten the per-sample lists and de-duplicate.
        distinct = {label for sample in labels for label in sample}
    else:
        distinct = raw_dataset[split].unique("label")

    # we will treat the label list as a list of string instead of int, consistent with model.config.label2id
    return [str(label) for label in distinct]
273
+
274
+
275
def main():
    """Fine-tune, evaluate and/or run prediction with a sequence-classification model.

    Steps, in order: parse `ModelArguments`, `DataTrainingArguments` and `TrainingArguments` (from the
    command line or a single JSON file), configure logging, detect a resumable checkpoint, load the
    dataset (hub or local csv/json), apply the split/column options, decide between regression,
    single-label and multi-label classification, load config/tokenizer/model, tokenize, build the
    `Trainer`, then run the phases selected by `--do_train`, `--do_eval` and `--do_predict`. Finally
    the model is pushed to the hub or a model card is created.

    Raises:
        ValueError: on conflicting or missing arguments, a non-empty output directory without a
            checkpoint, a missing split for a requested phase, or fewer than two classification labels.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if model_args.use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
            FutureWarning,
        )
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    # send_example_telemetry("run_classification", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files, or specify a dataset name
    # to load from huggingface/datasets. In either case, you can specify the key of the column(s) containing the text and
    # the key of the column containing the label. If multiple columns are specified for the text, they will be joined together
    # for the actual text value.
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
        # Try print some info about the dataset
        logger.info(f"Dataset loaded: {raw_datasets}")
        logger.info(raw_datasets)
    else:
        # Loading a dataset from your local files.
        # CSV/JSON training and evaluation files are needed.
        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}

        # Get the test dataset: you can provide your own CSV/JSON test file
        if training_args.do_predict:
            if data_args.test_file is None:
                raise ValueError("Need either a dataset name or a test file for `do_predict`.")
            train_extension = data_args.train_file.split(".")[-1]
            test_extension = data_args.test_file.split(".")[-1]
            # Explicit raise instead of `assert`, which is stripped when Python runs with -O.
            if test_extension != train_extension:
                raise ValueError("`test_file` should have the same extension (csv or json) as `train_file`.")
            data_files["test"] = data_args.test_file

        for key, path in data_files.items():
            logger.info(f"load a local file for {key}: {path}")

        # Local files are loaded with the "csv" builder when the train file ends in .csv, otherwise "json".
        builder_name = "csv" if data_args.train_file.endswith(".csv") else "json"
        raw_datasets = load_dataset(
            builder_name,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )

    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.

    if data_args.remove_splits is not None:
        for split in data_args.remove_splits.split(","):
            logger.info(f"removing split {split}")
            raw_datasets.pop(split)

    # Alias user-named splits to the canonical "train"/"validation"/"test" keys. The source key is only
    # dropped when it differs from the canonical one; otherwise the split just assigned would be deleted.
    if data_args.train_split_name is not None:
        logger.info(f"using {data_args.train_split_name} as train set")
        raw_datasets["train"] = raw_datasets[data_args.train_split_name]
        if data_args.train_split_name != "train":
            raw_datasets.pop(data_args.train_split_name)

    if data_args.validation_split_name is not None:
        logger.info(f"using {data_args.validation_split_name} as validation set")
        raw_datasets["validation"] = raw_datasets[data_args.validation_split_name]
        if data_args.validation_split_name != "validation":
            raw_datasets.pop(data_args.validation_split_name)

    if data_args.test_split_name is not None:
        logger.info(f"using {data_args.test_split_name} as test set")
        raw_datasets["test"] = raw_datasets[data_args.test_split_name]
        if data_args.test_split_name != "test":
            raw_datasets.pop(data_args.test_split_name)

    if data_args.remove_columns is not None:
        columns_to_remove = data_args.remove_columns.split(",")
        for split in raw_datasets.keys():
            for column in columns_to_remove:
                logger.info(f"removing column {column} from split {split}")
                # `remove_columns` returns a new dataset, so the result must be stored back;
                # the original discarded it, which made this option a no-op.
                raw_datasets[split] = raw_datasets[split].remove_columns(column)

    if data_args.label_column_name is not None and data_args.label_column_name != "label":
        for key in raw_datasets.keys():
            raw_datasets[key] = raw_datasets[key].rename_column(data_args.label_column_name, "label")

    # Trying to have good defaults here, don't hesitate to tweak to your needs.

    is_regression = (
        raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if data_args.do_regression is None
        else data_args.do_regression
    )

    is_multi_label = False
    if is_regression:
        label_list = None
        num_labels = 1
        # regression requires float as label type, let's cast it if needed
        for split in raw_datasets.keys():
            if raw_datasets[split].features["label"].dtype not in ["float32", "float64"]:
                logger.warning(
                    f"Label type for {split} set to float32, was {raw_datasets[split].features['label'].dtype}"
                )
                features = raw_datasets[split].features
                features.update({"label": Value("float32")})
                try:
                    raw_datasets[split] = raw_datasets[split].cast(features)
                except TypeError:
                    logger.error(
                        f"Unable to cast {split} set to float32, please check the labels are correct, or maybe try with --do_regression=False"
                    )
                    raise

    else:  # classification
        if raw_datasets["train"].features["label"].dtype == "list":  # multi-label classification
            is_multi_label = True
            logger.info("Label type is list, doing multi-label classification")
        # Trying to find the number of labels in a multi-label classification task
        # We have to deal with common cases that labels appear in the training set but not in the validation/test set.
        # So we build the label list from the union of labels in train/val/test.
        label_list = get_label_list(raw_datasets, split="train")
        for split in ["validation", "test"]:
            if split in raw_datasets:
                val_or_test_labels = get_label_list(raw_datasets, split=split)
                diff = set(val_or_test_labels).difference(set(label_list))
                if len(diff) > 0:
                    # add the labels that appear in val/test but not in train, throw a warning
                    logger.warning(
                        f"Labels {diff} in {split} set but not in training set, adding them to the label list"
                    )
                    label_list += list(diff)
        # if label is -1, we throw a warning and remove it from the label list.
        # `get_label_list` returns strings, so the sentinel is "-1" here; the original compared against
        # the int -1 (never true) while mutating the list under iteration. The filter is limited to the
        # single-label case because only that path encodes -1 specially in `preprocess_function`.
        if not is_multi_label and "-1" in label_list:
            logger.warning("Label -1 found in label list, removing it.")
            label_list = [label for label in label_list if label != "-1"]

        label_list.sort()
        num_labels = len(label_list)
        if num_labels <= 1:
            raise ValueError("You need more than one label to do classification.")

    # Load pretrained model and tokenizer
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task="text-classification",
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )

    if is_regression:
        config.problem_type = "regression"
        logger.info("setting problem type to regression")
    elif is_multi_label:
        config.problem_type = "multi_label_classification"
        logger.info("setting problem type to multi label classification")
    else:
        config.problem_type = "single_label_classification"
        logger.info("setting problem type to single label classification")

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )

    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    # for training, we will update the config with label infos,
    # if do_train is not set, we will use the label infos in the config
    if training_args.do_train and not is_regression:  # classification, training
        label_to_id = {v: i for i, v in enumerate(label_list)}
        # update config with label infos
        if model.config.label2id != label_to_id:
            logger.warning(
                "The label2id key in the model config.json is not equal to the label2id key of this "
                "run. You can ignore this if you are doing finetuning."
            )
        model.config.label2id = label_to_id
        model.config.id2label = {id: label for label, id in label_to_id.items()}
    elif not is_regression:  # classification, but not training
        logger.info("using label infos in the model config")
        logger.info("label2id: {}".format(model.config.label2id))
        label_to_id = model.config.label2id
    else:  # regression
        label_to_id = None

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    def multi_labels_to_ids(labels: List[str]) -> List[float]:
        """Encode one sample's labels as a multi-hot vector of floats."""
        ids = [0.0] * len(label_to_id)  # BCELoss requires float as target type
        for label in labels:
            # `label_to_id` is keyed by strings (see `get_label_list`), so normalise before the lookup.
            ids[label_to_id[str(label)]] = 1.0
        return ids

    def preprocess_function(examples):
        """Tokenize a batch and, for classification, convert its labels to ids."""
        if data_args.text_column_names is not None:
            text_column_names = data_args.text_column_names.split(",")
            # join together text columns into "sentence" column
            examples["sentence"] = examples[text_column_names[0]]
            for column in text_column_names[1:]:
                for i in range(len(examples[column])):
                    examples["sentence"][i] += data_args.text_column_delimiter + examples[column][i]
        # Tokenize the texts
        result = tokenizer(examples["sentence"], padding=padding, max_length=max_seq_length, truncation=True)
        if label_to_id is not None and "label" in examples:
            if is_multi_label:
                result["label"] = [multi_labels_to_ids(l) for l in examples["label"]]
            else:
                # -1 is passed through unchanged instead of being looked up in `label_to_id`.
                result["label"] = [(label_to_id[str(l)] if l != -1 else -1) for l in examples["label"]]
        return result

    # Running the preprocessing pipeline on all the datasets
    with training_args.main_process_first(desc="dataset map pre-processing"):
        raw_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )

    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset.")
        train_dataset = raw_datasets["train"]
        if data_args.shuffle_train_dataset:
            logger.info("Shuffling the training dataset")
            train_dataset = train_dataset.shuffle(seed=data_args.shuffle_seed)
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        # Pick the split that actually exists. The original accepted "validation_matched"/"test_matched"
        # in its check but then always indexed "validation"/"test", which raised KeyError.
        if "validation" in raw_datasets:
            eval_dataset = raw_datasets["validation"]
        elif "validation_matched" in raw_datasets:
            eval_dataset = raw_datasets["validation_matched"]
        elif "test" in raw_datasets or "test_matched" in raw_datasets:
            logger.warning("Validation dataset not found. Falling back to test dataset for validation.")
            eval_dataset = raw_datasets["test"] if "test" in raw_datasets else raw_datasets["test_matched"]
        else:
            raise ValueError("--do_eval requires a validation or test dataset if validation is not defined.")

        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    if training_args.do_predict or data_args.test_file is not None:
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test"]
        # remove label column if it exists
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))

    # Log a few random samples from the training set:
    if training_args.do_train:
        # Cap at the dataset size: random.sample raises ValueError when asked for more items than exist.
        for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    if data_args.metric_name is not None:
        metric = (
            evaluate.load(data_args.metric_name, config_name="multilabel", cache_dir=model_args.cache_dir)
            if is_multi_label
            else evaluate.load(data_args.metric_name, cache_dir=model_args.cache_dir)
        )
        logger.info(f"Using metric {data_args.metric_name} for evaluation.")
    else:
        if is_regression:
            metric = evaluate.load("mse", cache_dir=model_args.cache_dir)
            logger.info("Using mean squared error (mse) as regression score, you can use --metric_name to overwrite.")
        else:
            if is_multi_label:
                metric = evaluate.load("f1", config_name="multilabel", cache_dir=model_args.cache_dir)
                logger.info(
                    "Using multilabel F1 for multi-label classification task, you can use --metric_name to overwrite."
                )
            else:
                metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
                logger.info("Using accuracy as classification score, you can use --metric_name to overwrite.")

    def compute_metrics(p: EvalPrediction):
        """Turn raw model outputs into predictions and score them with the selected metric."""
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        if is_regression:
            preds = np.squeeze(preds)
            result = metric.compute(predictions=preds, references=p.label_ids)
        elif is_multi_label:
            # convert logits to multi-hot encoding (element-wise; avoids reusing the name `p` in a loop)
            preds = np.where(preds > 0, 1, 0)
            # Micro F1 is commonly used in multi-label classification
            result = metric.compute(predictions=preds, references=p.label_ids, average="micro")
        else:
            preds = np.argmax(preds, axis=1)
            result = metric.compute(predictions=preds, references=p.label_ids)
        if len(result) > 1:
            result["combined_score"] = np.mean(list(result.values())).item()
        return result

    # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
    # we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
        trainer.save_model()  # Saves the tokenizer too for easy upload
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(eval_dataset=eval_dataset)
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")
        # Removing the `label` columns if exists because it might contains -1 and Trainer won't like that.
        if "label" in predict_dataset.features:
            predict_dataset = predict_dataset.remove_columns("label")
        predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
        # Same tuple handling as in `compute_metrics`: keep only the first element.
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        if is_regression:
            predictions = np.squeeze(predictions)
        elif is_multi_label:
            # Convert logits to multi-hot encoding. We compare the logits to 0 instead of 0.5, because the sigmoid is not applied.
            # You can also pass `preprocess_logits_for_metrics=lambda logits, labels: nn.functional.sigmoid(logits)` to the Trainer
            # and set p > 0.5 below (less efficient in this case)
            predictions = np.where(predictions > 0, 1, 0)
        else:
            predictions = np.argmax(predictions, axis=1)
        output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
        if trainer.is_world_process_zero():
            with open(output_predict_file, "w") as writer:
                logger.info("***** Predict results *****")
                writer.write("index\tprediction\n")
                for index, item in enumerate(predictions):
                    if is_regression:
                        writer.write(f"{index}\t{item:3.3f}\n")
                    elif is_multi_label:
                        # recover from multi-hot encoding
                        item = [label_list[i] for i in range(len(item)) if item[i] == 1]
                        writer.write(f"{index}\t{item}\n")
                    else:
                        item = label_list[item]
                        writer.write(f"{index}\t{item}\n")
        logger.info("Predict results saved at {}".format(output_predict_file))
    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)
755
+
756
+
757
def _mp_fn(index):
    """Entry point for xla_spawn (TPUs); `index` is accepted but not used."""
    main()
760
+
761
+
762
# Run the full pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()