rishiraj committed on
Commit
7d19bc8
·
1 Parent(s): 8d1ee8d

Upload configs.py

Browse files
Files changed (1) hide show
  1. configs.py +272 -0
configs.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import dataclasses
17
+ import os
18
+ import sys
19
+ from dataclasses import dataclass, field
20
+ from typing import Any, Dict, List, NewType, Optional, Tuple
21
+
22
+ import transformers
23
+ from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
24
+
25
+
26
# Keys of the causal-LM model mapping exposed by `transformers`.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
# The `model_type` attribute of each of those keys, in the same order.
MODEL_TYPES = tuple(config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES)


# Nominal alias used in annotations to mean "an instance of some dataclass".
DataClassType = NewType("DataClassType", Any)
31
+
32
+
33
class H4ArgumentParser(HfArgumentParser):
    """
    `HfArgumentParser` that can combine a YAML config file with `--key=value` command line overrides.
    """

    @staticmethod
    def _split_overrides(other_args: Optional[List[str]]) -> Dict[str, str]:
        """Turn `['--key=value', ...]` into `{'key': 'value', ...}`; `None` yields an empty dict."""
        overrides = {}
        for arg in other_args or []:
            # Split on the first "=" only, so that values may themselves contain "=".
            key, sep, value = arg.partition("=")
            if not sep:
                raise ValueError(f"Expected an argument of the form --key=value, got: {arg}")
            overrides[key.strip("-")] = value
        return overrides

    @staticmethod
    def _cast_override(base_type: Any, val: str) -> Any:
        """Convert the raw command line string `val` to the declared field type `base_type`."""
        from typing import Union, get_args, get_origin

        # Unwrap `Optional[X]` so that optional fields are cast exactly like plain `X` fields.
        if get_origin(base_type) is Union:
            non_none = [t for t in get_args(base_type) if t is not type(None)]
            if len(non_none) == 1:
                base_type = non_none[0]

        # cast type for ints, floats (default to strings)
        if base_type in (int, float):
            return base_type(val)
        if base_type == List[str]:
            return val.split(",")
        # bool of a non-empty string is True, so we manually check for bools
        if base_type is bool:
            return val in ("true", "True")
        return val

    def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
        """
        Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.

        Args:
            yaml_arg (`str`):
                The path to the config file used
            other_args (`List[str]`, *optional*):
                A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
                Keys that match no field of any dataclass are ignored.

        Returns:
            [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line

        Raises:
            ValueError: if an argument is not of the form `--key=value`, if a value cannot be cast to the
                field's numeric type, or if the same key is applied to more than one dataclass.
        """
        arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))
        overrides = self._split_overrides(other_args)

        outputs = []
        used_args = {}

        # overwrite the default/loaded value with the value provided to the command line
        # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
        for data_yaml, data_class in zip(arg_list, self.dataclass_types):
            keys = {f.name for f in dataclasses.fields(data_yaml) if f.init}
            inputs = {k: v for k, v in vars(data_yaml).items() if k in keys}
            for arg, val in overrides.items():
                # add only if in keys
                if arg not in keys:
                    continue

                base_type = data_yaml.__dataclass_fields__[arg].type
                inputs[arg] = self._cast_override(base_type, val)

                # add to used-args so we can check if double add
                if arg in used_args:
                    raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")
                used_args[arg] = val

            outputs.append(data_class(**inputs))

        return outputs

    def parse(self) -> DataClassType | Tuple[DataClassType]:
        """
        Parse `sys.argv` and return the populated dataclass(es).

        A first argument ending in ".yaml" is treated as a config file; any further arguments are applied as
        `--key=value` overrides on top of it. Otherwise all arguments are parsed as regular command line flags.

        Returns:
            The single dataclass instance when exactly one was parsed, otherwise the full sequence.
        """
        if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
            # If we pass only one argument to the script and it's the path to a YAML file,
            # let's parse it to get our arguments.
            output = self.parse_yaml_file(os.path.abspath(sys.argv[1]))
        # parse command line args and yaml file
        elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
            output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:])
        # parse command line args only
        else:
            output = self.parse_args_into_dataclasses()

        if len(output) == 1:
            output = output[0]
        return output
105
+
106
+
107
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.

    Raises:
        ValueError: on construction, if both `load_in_8bit` and `load_in_4bit` are set.
    """

    base_model_revision: Optional[str] = field(
        default=None,
        metadata={"help": "The base model checkpoint for weights initialization with PEFT adatpers."},
    )
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Annotated as Optional because the default is None.
    model_code_revision: Optional[str] = field(default=None, metadata={"help": "The branch of the IFT model"})
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
    use_flash_attention_2: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`"
            )
        },
    )

    # PEFT / LoRA settings.
    use_peft: bool = field(
        default=False,
        metadata={"help": "Whether to use PEFT or not for training."},
    )
    lora_r: Optional[int] = field(
        default=16,
        metadata={"help": "LoRA R value."},
    )
    lora_alpha: Optional[int] = field(
        default=32,
        metadata={"help": "LoRA alpha."},
    )
    lora_dropout: Optional[float] = field(
        default=0.05,
        metadata={"help": "LoRA dropout."},
    )
    lora_target_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": "LoRA target modules."},
    )
    lora_modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": "Model layers to unfreeze & train"},
    )

    # Quantization settings; the 8-bit and 4-bit flags are mutually exclusive (checked in __post_init__).
    load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"})
    load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"})

    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
    )
    use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})

    def __post_init__(self):
        """Reject configurations that request 8-bit and 4-bit loading together."""
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
184
+
185
+
186
@dataclass
class DataArguments:
    """
    Options describing the data fed to the model for training and evaluation.
    """

    chat_template: Optional[str] = field(
        default=None,
        metadata={"help": "The chat template to use."},
    )
    dataset_mixer: Optional[Dict[str, float]] = field(
        default=None,
        metadata={"help": "Datasets and their proportions to be used for training ift/rl."},
    )
    # A factory is used so that every instance gets its own list.
    dataset_splits: Optional[List[str]] = field(
        default_factory=lambda: ["train", "test"],
        metadata={"help": "List of train test splits to use in the dataset"},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this value if set."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    truncation_side: Optional[str] = field(
        default=None,
        metadata={"help": "Truncation side to use for the tokenizer."},
    )
226
+
227
+
228
@dataclass
class SFTConfig(transformers.TrainingArguments):
    """
    Training arguments for SFT runs: `transformers.TrainingArguments` plus the fields declared below.

    For the inherited parameters, see:
    https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
    """

    max_seq_length: Optional[int] = field(
        default=None,
        metadata={"help": "Used by TRL for reward model training, which tries to read this parameter in init."},
    )
    # Declared here with a default of True.
    logging_first_step: bool = field(
        default=True,
        metadata={"help": "Whether to log and evaluate the first global_step or not."},
    )
    # Declared here with a default of "adamw_torch".
    optim: Optional[str] = field(default="adamw_torch")
243
+
244
+
245
@dataclass
class DPOConfig(transformers.TrainingArguments):
    """
    Training arguments for DPO runs: `transformers.TrainingArguments` plus the fields declared below.

    For the inherited parameters, see:
    https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
    """

    beta: Optional[float] = field(
        default=0.1,
        metadata={"help": "The beta factor in DPO loss. Higher beta means less divergence from the initial policy."},
    )
    hub_model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The Hub model branch to push the model to."},
    )
    # Declared here with a default of True.
    logging_first_step: bool = field(
        default=True,
        metadata={"help": "Whether to log and evaluate the first global_step or not."},
    )
    max_prompt_length: Optional[int] = field(
        default=None,
        metadata={"help": "For DPO, the maximum length of the prompt to use for conditioning the model."},
    )
    max_length: Optional[int] = field(
        default=None,
        metadata={"help": "Used by TRL for reward model training, which tries to read this parameter in init."},
    )
    # Declared here with a default of "rmsprop".
    optim: Optional[str] = field(default="rmsprop")
    # Declared here with a default of False.
    remove_unused_columns: bool = field(default=False)