Spaces:
Running
Running
Joshua Lochner
committed on
Commit
•
d34e3fe
1
Parent(s):
7dbc778
Fix training arguments dataclasses
Browse files
- src/shared.py +32 -10
- src/train.py +11 -4
- src/train_classifier.py +23 -18
src/shared.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from transformers.trainer_utils import get_last_checkpoint as glc
|
2 |
-
from transformers import TrainingArguments
|
3 |
import os
|
4 |
from utils import re_findall
|
5 |
import logging
|
@@ -76,14 +76,15 @@ _SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
|
|
76 |
SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
|
77 |
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
def extract_sponsor_matches(texts):
|
80 |
-
|
81 |
-
for text in texts:
|
82 |
-
if CustomTokens.NO_SEGMENT.value in text:
|
83 |
-
to_return.append([])
|
84 |
-
else:
|
85 |
-
to_return.append(re_findall(SEGMENT_MATCH_RE, text))
|
86 |
-
return to_return
|
87 |
|
88 |
|
89 |
@dataclass
|
@@ -134,6 +135,22 @@ class DatasetArguments:
|
|
134 |
},
|
135 |
)
|
136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
def __post_init__(self):
|
138 |
if self.train_file is None or self.validation_file is None:
|
139 |
raise ValueError(
|
@@ -234,7 +251,7 @@ def load_datasets(dataset_args: DatasetArguments):
|
|
234 |
|
235 |
|
236 |
@dataclass
|
237 |
-
class
|
238 |
seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']
|
239 |
|
240 |
num_train_epochs: float = field(
|
@@ -242,7 +259,7 @@ class CustomTrainingArguments(OutputArguments, TrainingArguments):
|
|
242 |
|
243 |
save_steps: int = field(default=5000, metadata={
|
244 |
'help': 'Save checkpoint every X updates steps.'})
|
245 |
-
eval_steps: int = field(default=
|
246 |
'help': 'Run an evaluation every X steps.'})
|
247 |
logging_steps: int = field(default=5000, metadata={
|
248 |
'help': 'Log every X updates steps.'})
|
@@ -311,6 +328,11 @@ class CustomTrainingArguments(OutputArguments, TrainingArguments):
|
|
311 |
)
|
312 |
|
313 |
|
|
|
|
|
|
|
|
|
|
|
314 |
logging.basicConfig()
|
315 |
logger = logging.getLogger(__name__)
|
316 |
|
|
|
1 |
from transformers.trainer_utils import get_last_checkpoint as glc
|
2 |
+
from transformers import Seq2SeqTrainingArguments, TrainingArguments
|
3 |
import os
|
4 |
from utils import re_findall
|
5 |
import logging
|
|
|
76 |
SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
|
77 |
|
78 |
|
79 |
+
def extract_sponsor_matches_from_text(text):
|
80 |
+
if CustomTokens.NO_SEGMENT.value in text:
|
81 |
+
return []
|
82 |
+
else:
|
83 |
+
return re_findall(SEGMENT_MATCH_RE, text)
|
84 |
+
|
85 |
+
|
86 |
def extract_sponsor_matches(texts):
|
87 |
+
return list(map(extract_sponsor_matches_from_text, texts))
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
|
90 |
@dataclass
|
|
|
135 |
},
|
136 |
)
|
137 |
|
138 |
+
c_train_file: Optional[str] = field(
|
139 |
+
default='c_train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
|
140 |
+
)
|
141 |
+
c_validation_file: Optional[str] = field(
|
142 |
+
default='c_valid.json',
|
143 |
+
metadata={
|
144 |
+
'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
|
145 |
+
},
|
146 |
+
)
|
147 |
+
c_test_file: Optional[str] = field(
|
148 |
+
default='c_test.json',
|
149 |
+
metadata={
|
150 |
+
'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
|
151 |
+
},
|
152 |
+
)
|
153 |
+
|
154 |
def __post_init__(self):
|
155 |
if self.train_file is None or self.validation_file is None:
|
156 |
raise ValueError(
|
|
|
251 |
|
252 |
|
253 |
@dataclass
|
254 |
+
class AdditionalTrainingArguments:
|
255 |
seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']
|
256 |
|
257 |
num_train_epochs: float = field(
|
|
|
259 |
|
260 |
save_steps: int = field(default=5000, metadata={
|
261 |
'help': 'Save checkpoint every X updates steps.'})
|
262 |
+
eval_steps: int = field(default=25000, metadata={
|
263 |
'help': 'Run an evaluation every X steps.'})
|
264 |
logging_steps: int = field(default=5000, metadata={
|
265 |
'help': 'Log every X updates steps.'})
|
|
|
328 |
)
|
329 |
|
330 |
|
331 |
+
@dataclass
|
332 |
+
class CustomTrainingArguments(OutputArguments, AdditionalTrainingArguments):
|
333 |
+
pass
|
334 |
+
|
335 |
+
|
336 |
logging.basicConfig()
|
337 |
logger = logging.getLogger(__name__)
|
338 |
|
src/train.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
from preprocess import PreprocessingDatasetArguments
|
2 |
from shared import (
|
3 |
CustomTokens,
|
|
|
4 |
prepare_datasets,
|
5 |
load_datasets,
|
6 |
CustomTrainingArguments,
|
@@ -17,13 +17,15 @@ from transformers import (
|
|
17 |
DataCollatorForSeq2Seq,
|
18 |
HfArgumentParser,
|
19 |
Seq2SeqTrainer,
|
|
|
20 |
)
|
21 |
|
22 |
from transformers.utils import check_min_version
|
23 |
from transformers.utils.versions import require_version
|
|
|
24 |
|
25 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
26 |
-
check_min_version('4.
|
27 |
require_version('datasets>=1.8.0',
|
28 |
'To fix: pip install -r requirements.txt')
|
29 |
|
@@ -40,6 +42,11 @@ logging.basicConfig(
|
|
40 |
)
|
41 |
|
42 |
|
|
|
|
|
|
|
|
|
|
|
43 |
def main():
|
44 |
|
45 |
# See all possible arguments in src/transformers/training_args.py
|
@@ -48,8 +55,8 @@ def main():
|
|
48 |
|
49 |
hf_parser = HfArgumentParser((
|
50 |
ModelArguments,
|
51 |
-
|
52 |
-
|
53 |
))
|
54 |
model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()
|
55 |
|
|
|
|
|
1 |
from shared import (
|
2 |
CustomTokens,
|
3 |
+
DatasetArguments,
|
4 |
prepare_datasets,
|
5 |
load_datasets,
|
6 |
CustomTrainingArguments,
|
|
|
17 |
DataCollatorForSeq2Seq,
|
18 |
HfArgumentParser,
|
19 |
Seq2SeqTrainer,
|
20 |
+
Seq2SeqTrainingArguments,
|
21 |
)
|
22 |
|
23 |
from transformers.utils import check_min_version
|
24 |
from transformers.utils.versions import require_version
|
25 |
+
from dataclasses import dataclass
|
26 |
|
27 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
28 |
+
check_min_version('4.17.0')
|
29 |
require_version('datasets>=1.8.0',
|
30 |
'To fix: pip install -r requirements.txt')
|
31 |
|
|
|
42 |
)
|
43 |
|
44 |
|
45 |
+
@dataclass
|
46 |
+
class Seq2SeqTrainingArguments(CustomTrainingArguments, Seq2SeqTrainingArguments):
|
47 |
+
pass
|
48 |
+
|
49 |
+
|
50 |
def main():
|
51 |
|
52 |
# See all possible arguments in src/transformers/training_args.py
|
|
|
55 |
|
56 |
hf_parser = HfArgumentParser((
|
57 |
ModelArguments,
|
58 |
+
DatasetArguments,
|
59 |
+
Seq2SeqTrainingArguments
|
60 |
))
|
61 |
model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()
|
62 |
|
src/train_classifier.py
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
import logging
|
5 |
import os
|
6 |
import sys
|
7 |
-
from dataclasses import dataclass
|
8 |
from typing import Optional
|
9 |
|
10 |
import datasets
|
@@ -16,11 +16,20 @@ from transformers import (
|
|
16 |
EvalPrediction,
|
17 |
HfArgumentParser,
|
18 |
Trainer,
|
|
|
19 |
set_seed,
|
20 |
)
|
21 |
from transformers.utils import check_min_version
|
22 |
from transformers.utils.versions import require_version
|
23 |
-
from shared import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
from model import get_model_tokenizer, ModelArguments
|
25 |
|
26 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
@@ -32,23 +41,19 @@ os.environ['WANDB_DISABLED'] = 'true'
|
|
32 |
logger = logging.getLogger(__name__)
|
33 |
|
34 |
|
|
|
|
|
|
|
|
|
|
|
35 |
@dataclass
|
36 |
class ClassifierDatasetArguments(DatasetArguments):
|
37 |
-
train_file: Optional[str] =
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
|
44 |
-
},
|
45 |
-
)
|
46 |
-
test_file: Optional[str] = field(
|
47 |
-
default='c_test.json',
|
48 |
-
metadata={
|
49 |
-
'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
|
50 |
-
},
|
51 |
-
)
|
52 |
|
53 |
|
54 |
def main():
|
@@ -59,7 +64,7 @@ def main():
|
|
59 |
hf_parser = HfArgumentParser((
|
60 |
ModelArguments,
|
61 |
ClassifierDatasetArguments,
|
62 |
-
|
63 |
))
|
64 |
model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()
|
65 |
|
|
|
4 |
import logging
|
5 |
import os
|
6 |
import sys
|
7 |
+
from dataclasses import dataclass
|
8 |
from typing import Optional
|
9 |
|
10 |
import datasets
|
|
|
16 |
EvalPrediction,
|
17 |
HfArgumentParser,
|
18 |
Trainer,
|
19 |
+
TrainingArguments,
|
20 |
set_seed,
|
21 |
)
|
22 |
from transformers.utils import check_min_version
|
23 |
from transformers.utils.versions import require_version
|
24 |
+
from shared import (
|
25 |
+
CATEGORIES,
|
26 |
+
DatasetArguments,
|
27 |
+
prepare_datasets,
|
28 |
+
load_datasets,
|
29 |
+
CustomTrainingArguments,
|
30 |
+
train_from_checkpoint,
|
31 |
+
get_last_checkpoint
|
32 |
+
)
|
33 |
from model import get_model_tokenizer, ModelArguments
|
34 |
|
35 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
|
|
41 |
logger = logging.getLogger(__name__)
|
42 |
|
43 |
|
44 |
+
@dataclass
|
45 |
+
class ClassifierTrainingArguments(CustomTrainingArguments, TrainingArguments):
|
46 |
+
pass
|
47 |
+
|
48 |
+
|
49 |
@dataclass
|
50 |
class ClassifierDatasetArguments(DatasetArguments):
|
51 |
+
train_file: Optional[str] = DatasetArguments.__dataclass_fields__[
|
52 |
+
'c_train_file']
|
53 |
+
validation_file: Optional[str] = DatasetArguments.__dataclass_fields__[
|
54 |
+
'c_validation_file']
|
55 |
+
test_file: Optional[str] = DatasetArguments.__dataclass_fields__[
|
56 |
+
'c_test_file']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
|
59 |
def main():
|
|
|
64 |
hf_parser = HfArgumentParser((
|
65 |
ModelArguments,
|
66 |
ClassifierDatasetArguments,
|
67 |
+
ClassifierTrainingArguments
|
68 |
))
|
69 |
model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()
|
70 |
|