rishiraj committed
Commit 129f129
1 Parent(s): 7d19bc8

Update configs.py

Files changed (1)
  1. configs.py +1 -213
configs.py CHANGED
@@ -14,174 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import dataclasses
-import os
-import sys
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, NewType, Optional, Tuple
 
-import transformers
-from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
-
-
-MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
-MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
-
-
-DataClassType = NewType("DataClassType", Any)
-
-
-class H4ArgumentParser(HfArgumentParser):
-    def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
-        """
-        Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.
-
-        Args:
-            yaml_arg (`str`):
-                The path to the config file used
-            other_args (`List[str]`, *optional*):
-                A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
-
-        Returns:
-            [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line
-        """
-        arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))
-
-        outputs = []
-        # strip other args list into dict of key-value pairs
-        other_args = {arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args}
-        used_args = {}
-
-        # overwrite the default/loaded value with the value provided to the command line
-        # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
-        for data_yaml, data_class in zip(arg_list, self.dataclass_types):
-            keys = {f.name for f in dataclasses.fields(data_yaml) if f.init}
-            inputs = {k: v for k, v in vars(data_yaml).items() if k in keys}
-            for arg, val in other_args.items():
-                # add only if in keys
-                if arg in keys:
-                    base_type = data_yaml.__dataclass_fields__[arg].type
-                    inputs[arg] = val
-
-                    # cast type for ints, floats (default to strings)
-                    if base_type in [int, float]:
-                        inputs[arg] = base_type(val)
-
-                    if base_type == List[str]:
-                        inputs[arg] = [str(v) for v in val.split(",")]
-
-                    # bool of a non-empty string is True, so we manually check for bools
-                    if base_type == bool:
-                        if val in ["true", "True"]:
-                            inputs[arg] = True
-                        else:
-                            inputs[arg] = False
-
-                    # add to used-args so we can check if double add
-                    if arg not in used_args:
-                        used_args[arg] = val
-                    else:
-                        raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")
-
-            obj = data_class(**inputs)
-            outputs.append(obj)
-
-        return outputs
-
-    def parse(self) -> DataClassType | Tuple[DataClassType]:
-        if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-            # If we pass only one argument to the script and it's the path to a YAML file,
-            # let's parse it to get our arguments.
-            output = self.parse_yaml_file(os.path.abspath(sys.argv[1]))
-        # parse command line args and yaml file
-        elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
-            output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:])
-        # parse command line args only
-        else:
-            output = self.parse_args_into_dataclasses()
-
-        if len(output) == 1:
-            output = output[0]
-        return output
-
-
-@dataclass
-class ModelArguments:
-    """
-    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
-    """
-
-    base_model_revision: Optional[str] = field(
-        default=None,
-        metadata={"help": ("The base model checkpoint for weights initialization with PEFT adapters.")},
-    )
-    model_name_or_path: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": (
-                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
-            )
-        },
-    )
-    model_revision: str = field(
-        default="main",
-        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
-    )
-    model_code_revision: str = field(default=None, metadata={"help": "The branch of the IFT model"})
-    torch_dtype: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": (
-                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
-                "dtype will be automatically derived from the model's weights."
-            ),
-            "choices": ["auto", "bfloat16", "float16", "float32"],
-        },
-    )
-    trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
-    use_flash_attention_2: bool = field(
-        default=False,
-        metadata={
-            "help": (
-                "Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`"
-            )
-        },
-    )
-    use_peft: bool = field(
-        default=False,
-        metadata={"help": ("Whether to use PEFT or not for training.")},
-    )
-    lora_r: Optional[int] = field(
-        default=16,
-        metadata={"help": ("LoRA R value.")},
-    )
-    lora_alpha: Optional[int] = field(
-        default=32,
-        metadata={"help": ("LoRA alpha.")},
-    )
-    lora_dropout: Optional[float] = field(
-        default=0.05,
-        metadata={"help": ("LoRA dropout.")},
-    )
-    lora_target_modules: Optional[List[str]] = field(
-        default=None,
-        metadata={"help": ("LoRA target modules.")},
-    )
-    lora_modules_to_save: Optional[List[str]] = field(
-        default=None,
-        metadata={"help": ("Model layers to unfreeze & train")},
-    )
-    load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"})
-    load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"})
-
-    bnb_4bit_quant_type: Optional[str] = field(
-        default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
-    )
-    use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
-
-    def __post_init__(self):
-        if self.load_in_8bit and self.load_in_4bit:
-            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
-
 
 @dataclass
 class DataArguments:
@@ -222,51 +57,4 @@ class DataArguments:
     )
     truncation_side: Optional[str] = field(
         default=None, metadata={"help": "Truncation side to use for the tokenizer."}
-    )
-
-
-@dataclass
-class SFTConfig(transformers.TrainingArguments):
-    """
-    Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
-    """
-
-    max_seq_length: Optional[int] = field(
-        default=None,
-        metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
-    )
-    logging_first_step: bool = field(
-        default=True,
-        metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
-    )
-    optim: Optional[str] = field(default="adamw_torch")
-
-
-@dataclass
-class DPOConfig(transformers.TrainingArguments):
-    """
-    Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
-    """
-
-    beta: Optional[float] = field(
-        default=0.1,
-        metadata={"help": "The beta factor in DPO loss. Higher beta means less divergence from the initial policy."},
-    )
-    hub_model_revision: Optional[str] = field(
-        default="main",
-        metadata={"help": ("The Hub model branch to push the model to.")},
-    )
-    logging_first_step: bool = field(
-        default=True,
-        metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
-    )
-    max_prompt_length: Optional[int] = field(
-        default=None,
-        metadata={"help": ("For DPO, the maximum length of the prompt to use for conditioning the model.")},
-    )
-    max_length: Optional[int] = field(
-        default=None,
-        metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
-    )
-    optim: Optional[str] = field(default="rmsprop")
-    remove_unused_columns: bool = field(default=False)
+    )
 
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
  import dataclasses
 
 
17
  from dataclasses import dataclass, field
18
  from typing import Any, Dict, List, NewType, Optional, Tuple
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  @dataclass
22
  class DataArguments:
 
57
  )
58
  truncation_side: Optional[str] = field(
59
  default=None, metadata={"help": "Truncation side to use for the tokenizer."}
60
+ )
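For reference, this commit strips the `H4ArgumentParser`, `ModelArguments`, `SFTConfig`, and `DPOConfig` definitions from configs.py, leaving only `DataArguments`. A minimal sketch of how the removed parser was typically driven, assuming the pre-commit version of configs.py and a hypothetical launcher script and YAML file (run_sft.py and config.yaml are illustrative, not part of this commit):

# run_sft.py -- hypothetical launcher written against the PRE-commit configs.py.
# After this commit the imports below would fail, since only DataArguments remains.
from configs import DataArguments, H4ArgumentParser, ModelArguments, SFTConfig


def main():
    parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
    # `python run_sft.py config.yaml` loads every field from the YAML file via parse();
    # extra flags such as `--learning_rate=2.0e-5 --use_peft=true` are routed through
    # parse_yaml_and_args(), which overrides the YAML values and re-casts ints,
    # floats, and bools from their command-line string form.
    model_args, data_args, training_args = parser.parse()
    print(model_args.model_name_or_path, data_args.truncation_side, training_args.optim)


if __name__ == "__main__":
    main()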