aapot committed on
Commit 06e20cd
1 Parent(s): da22eb7

Add pretrain hyperparams

Files changed (1)
  1. configure_pretraining.py +142 -0
configure_pretraining.py ADDED
@@ -0,0 +1,142 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Config controlling hyperparameters for pre-training ELECTRA."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+
+class PretrainingConfig(object):
+  """Defines pre-training hyperparameters."""
+
+  def __init__(self, model_name, data_dir, **kwargs):
+    self.model_name = model_name
+    self.debug = False  # debug mode for quickly running things
+    self.do_train = True  # pre-train ELECTRA
+    self.do_eval = False  # evaluate generator/discriminator on unlabeled data
+
+    # loss functions
+    # train ELECTRA or Electric? if both are false, trains a masked LM like BERT
+    self.electra_objective = True
+    self.electric_objective = False
+    self.gen_weight = 1.0  # masked language modeling / generator loss
+    self.disc_weight = 50.0  # discriminator loss
+    self.mask_prob = 0.15  # percent of input tokens to mask out / replace
+
+    # optimization
+    self.learning_rate = 2e-4
+    self.lr_decay_power = 1.0  # linear weight decay by default
+    self.weight_decay_rate = 0.01
+    self.num_warmup_steps = 20000
+
+    # training settings
+    self.iterations_per_loop = 200
+    self.save_checkpoints_steps = 50000
+    self.num_train_steps = 1000000
+    self.num_eval_steps = 10000
+    self.keep_checkpoint_max = 5  # maximum number of recent checkpoint files to keep;
+                                  # change to 0 or None to keep all checkpoints
+
+    # model settings
+    self.model_size = "base"  # one of "small", "base", or "large"
+    # override the default transformer hparams for the provided model size; see
+    # modeling.BertConfig for the possible hparams and util.training_utils for
+    # the defaults
+    self.model_hparam_overrides = (
+        kwargs["model_hparam_overrides"]
+        if "model_hparam_overrides" in kwargs else {})
+    self.embedding_size = None  # bert hidden size by default
+    self.vocab_size = 50265  # number of tokens in the vocabulary
+    self.do_lower_case = False  # lowercase the input?
+
+    # generator settings
+    self.uniform_generator = False  # generator is uniform at random
+    self.two_tower_generator = False  # generator is a two-tower cloze model
+    self.untied_generator_embeddings = False  # tie generator/discriminator
+                                              # token embeddings?
+    self.untied_generator = True  # tie all generator/discriminator weights?
+    self.generator_layers = 1.0  # frac of discriminator layers for generator
+    self.generator_hidden_size = 0.25  # frac of discrim hidden size for gen
+    self.disallow_correct = False  # force the generator to sample incorrect
+                                   # tokens (so 15% of tokens are always
+                                   # fake)
+    self.temperature = 1.0  # temperature for sampling from generator
+
+    # batch sizes
+    self.max_seq_length = 512
+    self.train_batch_size = 256
+    self.eval_batch_size = 128
+
+    # TPU settings
+    self.use_tpu = True
+    self.num_tpu_cores = 8
+    self.tpu_job_name = None
+    self.tpu_name = "local"  # cloud TPU to use for training
+    self.tpu_zone = None  # GCE zone where the Cloud TPU is located in
+    self.gcp_project = None  # project name for the Cloud TPU-enabled project
+
+    # default locations of data files
+    self.pretrain_tfrecords = "/researchdisk/training_dataset_sentences/train_tokenized_512/pretrain_data.tfrecord*"
+    self.vocab_file = "/researchdisk/convbert-base-finnish/vocab.txt"
+    self.model_dir = "/researchdisk/electra-base-finnish"
+    results_dir = os.path.join(self.model_dir, "results")
+    self.results_txt = os.path.join(results_dir, "unsup_results.txt")
+    self.results_pkl = os.path.join(results_dir, "unsup_results.pkl")
+
+    # update defaults with passed-in hyperparameters
+    self.update(kwargs)
+
+    self.max_predictions_per_seq = int((self.mask_prob + 0.005) *
+                                       self.max_seq_length)
+
+    # debug-mode settings
+    if self.debug:
+      self.train_batch_size = 8
+      self.num_train_steps = 20
+      self.eval_batch_size = 4
+      self.iterations_per_loop = 1
+      self.num_eval_steps = 2
+
+    # defaults for different-sized model
+    if self.model_size == "small":
+      self.embedding_size = 128
+    # Here are the hyperparameters we used for larger models; see Table 6 in the
+    # paper for the full hyperparameters
+    else:
+      self.max_seq_length = 512
+      self.learning_rate = 2e-4
+      if self.model_size == "base":
+        self.embedding_size = 768
+        self.generator_hidden_size = 0.33333
+        self.train_batch_size = 256
+      else:
+        self.embedding_size = 1024
+        self.mask_prob = 0.25
+        self.train_batch_size = 2048
+    if self.electric_objective:
+      self.two_tower_generator = True  # electric requires a two-tower generator
+
+    # passed-in-arguments override (for example) debug-mode defaults
+    self.update(kwargs)
+
+  def update(self, kwargs):
+    for k, v in kwargs.items():
+      if k not in self.__dict__:
+        raise ValueError("Unknown hparam " + k)
+      self.__dict__[k] = v
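
For reference, a minimal sketch of how this config is typically consumed, assuming an upstream ELECTRA-style entry point that builds a `PretrainingConfig` from a model name, a data directory, and a dictionary of hyperparameter overrides. The model name, data directory, and override values below are illustrative only and not taken from this commit.

```python
# Minimal usage sketch; names and values below are illustrative assumptions.
from configure_pretraining import PretrainingConfig

config = PretrainingConfig(
    model_name="electra-base-finnish",  # hypothetical run name
    data_dir="/researchdisk",           # hypothetical data directory
    train_batch_size=128,               # kwargs override the defaults via update()
    num_train_steps=500000,
)

# With the default mask_prob=0.15 and max_seq_length=512 the config computes
# max_predictions_per_seq = int((0.15 + 0.005) * 512) = 79.
print(config.max_predictions_per_seq)  # 79

# Unknown keys are rejected by PretrainingConfig.update(), e.g.
# PretrainingConfig("name", "/data", batchsize=128) raises
# ValueError("Unknown hparam batchsize").
```

Note that `update(kwargs)` is called twice so that explicitly passed hyperparameters win over both the debug-mode and the model-size defaults applied later in `__init__`.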