Joshua Lochner committed
Commit d34e3fe
1 Parent(s): 7dbc778

Fix training arguments dataclasses

Files changed (3):
  1. src/shared.py +32 -10
  2. src/train.py +11 -4
  3. src/train_classifier.py +23 -18
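
The fix splits the training-argument dataclasses so each script composes its own final class: src/shared.py now defines a plain AdditionalTrainingArguments mixin (instead of one CustomTrainingArguments that inherited transformers.TrainingArguments directly), and src/train.py / src/train_classifier.py combine it with the transformers base each actually needs (Seq2SeqTrainingArguments and TrainingArguments respectively). A minimal sketch of that mixin pattern, with illustrative class and field names rather than the repo's own:

from dataclasses import dataclass, field

@dataclass
class BaseArguments:               # stands in for transformers.TrainingArguments
    output_dir: str = 'out'
    eval_steps: int = 500          # library default

@dataclass
class ExtraArguments:              # stands in for AdditionalTrainingArguments
    eval_steps: int = field(default=25000, metadata={
        'help': 'Run an evaluation every X steps.'})

@dataclass
class CombinedArguments(ExtraArguments, BaseArguments):
    pass

args = CombinedArguments()
assert args.eval_steps == 25000    # the mixin's default wins in the MRO
assert args.output_dir == 'out'    # everything else is inherited unchanged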
src/shared.py CHANGED
@@ -1,5 +1,5 @@
 from transformers.trainer_utils import get_last_checkpoint as glc
-from transformers import TrainingArguments
+from transformers import Seq2SeqTrainingArguments, TrainingArguments
 import os
 from utils import re_findall
 import logging
@@ -76,14 +76,15 @@ _SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
 SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'


+def extract_sponsor_matches_from_text(text):
+    if CustomTokens.NO_SEGMENT.value in text:
+        return []
+    else:
+        return re_findall(SEGMENT_MATCH_RE, text)
+
+
 def extract_sponsor_matches(texts):
-    to_return = []
-    for text in texts:
-        if CustomTokens.NO_SEGMENT.value in text:
-            to_return.append([])
-        else:
-            to_return.append(re_findall(SEGMENT_MATCH_RE, text))
-    return to_return
+    return list(map(extract_sponsor_matches_from_text, texts))


 @dataclass
@@ -134,6 +135,22 @@ class DatasetArguments:
         },
     )

+    c_train_file: Optional[str] = field(
+        default='c_train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
+    )
+    c_validation_file: Optional[str] = field(
+        default='c_valid.json',
+        metadata={
+            'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
+        },
+    )
+    c_test_file: Optional[str] = field(
+        default='c_test.json',
+        metadata={
+            'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
+        },
+    )
+
     def __post_init__(self):
         if self.train_file is None or self.validation_file is None:
             raise ValueError(
@@ -234,7 +251,7 @@ def load_datasets(dataset_args: DatasetArguments):


 @dataclass
-class CustomTrainingArguments(OutputArguments, TrainingArguments):
+class AdditionalTrainingArguments:
     seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']

     num_train_epochs: float = field(
@@ -242,7 +259,7 @@ class CustomTrainingArguments(OutputArguments, TrainingArguments):

     save_steps: int = field(default=5000, metadata={
         'help': 'Save checkpoint every X updates steps.'})
-    eval_steps: int = field(default=5000, metadata={
+    eval_steps: int = field(default=25000, metadata={
        'help': 'Run an evaluation every X steps.'})
     logging_steps: int = field(default=5000, metadata={
         'help': 'Log every X updates steps.'})
@@ -311,6 +328,11 @@ class CustomTrainingArguments(OutputArguments, TrainingArguments):
     )


+@dataclass
+class CustomTrainingArguments(OutputArguments, AdditionalTrainingArguments):
+    pass
+
+
 logging.basicConfig()
 logger = logging.getLogger(__name__)

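Both the reused seed field above and the ClassifierDatasetArguments overrides further below lean on the same dataclasses detail: an entry of __dataclass_fields__ is a Field object, and assigning a Field object as a class attribute makes @dataclass adopt it as the field specifier, default and help metadata included. A minimal sketch under illustrative names (note the Field object is shared between the two classes, so this is best limited to straightforward default/metadata reuse):

from dataclasses import dataclass, field, fields

@dataclass
class Defaults:
    c_train_file: str = field(default='c_train.json', metadata={
        'help': 'The input training data file (a jsonlines file).'})

@dataclass
class Reuser:
    # Adopt the existing Field: same default, same metadata, new name.
    train_file: str = Defaults.__dataclass_fields__['c_train_file']

print(Reuser().train_file)                 # c_train.json
print(fields(Reuser)[0].metadata['help'])  # The input training data file (a jsonlines file).
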
src/train.py CHANGED
@@ -1,6 +1,6 @@
-from preprocess import PreprocessingDatasetArguments
 from shared import (
     CustomTokens,
+    DatasetArguments,
     prepare_datasets,
     load_datasets,
     CustomTrainingArguments,
@@ -17,13 +17,15 @@ from transformers import (
     DataCollatorForSeq2Seq,
     HfArgumentParser,
     Seq2SeqTrainer,
+    Seq2SeqTrainingArguments,
 )

 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
+from dataclasses import dataclass

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version('4.13.0.dev0')
+check_min_version('4.17.0')
 require_version('datasets>=1.8.0',
                 'To fix: pip install -r requirements.txt')

@@ -40,6 +42,11 @@ logging.basicConfig(
 )


+@dataclass
+class Seq2SeqTrainingArguments(CustomTrainingArguments, Seq2SeqTrainingArguments):
+    pass
+
+
 def main():

     # See all possible arguments in src/transformers/training_args.py
@@ -48,8 +55,8 @@ def main():

     hf_parser = HfArgumentParser((
         ModelArguments,
-        PreprocessingDatasetArguments,
-        CustomTrainingArguments
+        DatasetArguments,
+        Seq2SeqTrainingArguments
     ))
     model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()

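One subtlety in the hunk above: the local Seq2SeqTrainingArguments deliberately inherits from the transformers class of the same name. That works because base-class expressions are evaluated before the class statement rebinds the name, so the second base still refers to the imported class. A toy demonstration of the same idiom, with illustrative names:

class Widget:                # stands in for the imported class
    source = 'library'

class Widget(Widget):        # bases are evaluated first, so this is the old Widget
    source = 'project'

assert Widget.source == 'project'
assert Widget.__mro__[1].source == 'library'   # the shadowed original is still the base
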
src/train_classifier.py CHANGED
@@ -4,7 +4,7 @@
 import logging
 import os
 import sys
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Optional

 import datasets
@@ -16,11 +16,20 @@ from transformers import (
     EvalPrediction,
     HfArgumentParser,
     Trainer,
+    TrainingArguments,
     set_seed,
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
-from shared import CATEGORIES, DatasetArguments, prepare_datasets, load_datasets, CustomTrainingArguments, train_from_checkpoint, get_last_checkpoint
+from shared import (
+    CATEGORIES,
+    DatasetArguments,
+    prepare_datasets,
+    load_datasets,
+    CustomTrainingArguments,
+    train_from_checkpoint,
+    get_last_checkpoint
+)
 from model import get_model_tokenizer, ModelArguments

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -32,23 +41,19 @@ os.environ['WANDB_DISABLED'] = 'true'
 logger = logging.getLogger(__name__)


+@dataclass
+class ClassifierTrainingArguments(CustomTrainingArguments, TrainingArguments):
+    pass
+
+
 @dataclass
 class ClassifierDatasetArguments(DatasetArguments):
-    train_file: Optional[str] = field(
-        default='c_train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
-    )
-    validation_file: Optional[str] = field(
-        default='c_valid.json',
-        metadata={
-            'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
-        },
-    )
-    test_file: Optional[str] = field(
-        default='c_test.json',
-        metadata={
-            'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
-        },
-    )
+    train_file: Optional[str] = DatasetArguments.__dataclass_fields__[
+        'c_train_file']
+    validation_file: Optional[str] = DatasetArguments.__dataclass_fields__[
+        'c_validation_file']
+    test_file: Optional[str] = DatasetArguments.__dataclass_fields__[
+        'c_test_file']


 def main():
@@ -59,7 +64,7 @@ def main():
     hf_parser = HfArgumentParser((
         ModelArguments,
         ClassifierDatasetArguments,
-        CustomTrainingArguments
+        ClassifierTrainingArguments
     ))
     model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()

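
For all three scripts the payoff is the same: HfArgumentParser turns every dataclass field into a command-line flag, taking the default from the field and the help text from metadata['help'], so the composed classes expose library and project options on one command line. A small self-contained sketch (ToyArguments is hypothetical; the parser API is transformers' own):

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ToyArguments:
    train_file: str = field(default='c_train.json', metadata={
        'help': 'The input training data file (a jsonlines file).'})

(args,) = HfArgumentParser(ToyArguments).parse_args_into_dataclasses(
    args=['--train_file', 'other.json'])
print(args.train_file)  # other.json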