File size: 10,278 Bytes
4ea2eae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 |
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import os
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, NewType, Optional, Tuple, Union, get_args, get_origin

import transformers
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
# Configuration classes that appear as keys of transformers' causal-LM model mapping.
MODEL_CONFIG_CLASSES = [*MODEL_FOR_CAUSAL_LM_MAPPING.keys()]
# The `model_type` attribute of each of those configuration classes, in the same order.
MODEL_TYPES = tuple(config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES)
# Distinct type name (wrapping `Any`) used below to annotate parsed dataclass instances.
DataClassType = NewType("DataClassType", Any)
class H4ArgumentParser(HfArgumentParser):
    """
    `HfArgumentParser` that can additionally read its values from a YAML file and then apply `--arg=value`
    overrides given on the command line.
    """

    @staticmethod
    def _parse_overrides(other_args: Optional[List[str]]) -> Dict[str, str]:
        """Convert `['--arg=val', ...]` into `{'arg': 'val', ...}`; a repeated flag keeps its last value."""
        overrides = {}
        for raw in other_args or []:
            # Split on the first "=" only, so values that themselves contain "=" are kept intact.
            name, sep, val = raw.partition("=")
            if not sep:
                raise ValueError(f"Expected a command line override of the form --arg=value, got: {raw}")
            # Only the leading dashes belong to the flag syntax.
            overrides[name.lstrip("-")] = val
        return overrides

    @staticmethod
    def _unwrap_optional(field_type: Any) -> Any:
        """Return `X` for an `Optional[X]` annotation; every other annotation is returned unchanged."""
        if get_origin(field_type) is Union:
            non_none = [t for t in get_args(field_type) if t is not type(None)]
            if len(non_none) == 1:
                return non_none[0]
        return field_type

    @classmethod
    def _cast_override(cls, field_type: Any, val: str) -> Any:
        """Cast the raw command line string `val` according to the annotation of the target field."""
        base_type = cls._unwrap_optional(field_type)
        # cast type for ints, floats (default to strings)
        if base_type in [int, float]:
            return base_type(val)
        if base_type == List[str]:
            return [str(v) for v in val.split(",")]
        # bool of a non-empty string is True, so we manually check for bools
        if base_type is bool:
            return val in ["true", "True"]
        return val

    def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
        """
        Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.

        Args:
            yaml_arg (`str`):
                The path to the config file used
            other_args (`List[str]`, *optional*):
                A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
                `None` is treated as an empty list. Names that match no dataclass field are ignored.

        Returns:
            [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line

        Raises:
            `ValueError`: if an entry of `other_args` has no `=`, or if the same argument name matches a field
            of more than one dataclass.
        """
        arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))
        overrides = self._parse_overrides(other_args)

        outputs = []
        used_args = {}

        # overwrite the default/loaded value with the value provided to the command line
        # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
        for data_yaml, data_class in zip(arg_list, self.dataclass_types):
            field_types = {f.name: f.type for f in dataclasses.fields(data_yaml) if f.init}
            inputs = {k: v for k, v in vars(data_yaml).items() if k in field_types}
            for arg, val in overrides.items():
                # add only if it is an init field of this dataclass
                if arg not in field_types:
                    continue
                inputs[arg] = self._cast_override(field_types[arg], val)

                # add to used-args so we can check if double add
                if arg in used_args:
                    raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")
                used_args[arg] = val

            outputs.append(data_class(**inputs))

        return outputs

    def parse(self) -> DataClassType | Tuple[DataClassType]:
        """
        Build the dataclasses from `sys.argv`.

        A single argument ending in `.yaml` is parsed as a YAML config; a `.yaml` first argument followed by
        further arguments is parsed as a YAML config plus command line overrides; anything else is handed to
        `parse_args_into_dataclasses`. When exactly one dataclass results, it is returned on its own rather
        than inside a container.
        """
        if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
            # If we pass only one argument to the script and it's the path to a YAML file,
            # let's parse it to get our arguments.
            output = self.parse_yaml_file(os.path.abspath(sys.argv[1]))
        # parse command line args and yaml file
        elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
            output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:])
        # parse command line args only
        else:
            output = self.parse_args_into_dataclasses()

        if len(output) == 1:
            output = output[0]
        return output
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.

    Each field carries its own description in `metadata["help"]`. Construction raises `ValueError` when both
    `load_in_8bit` and `load_in_4bit` are set.
    """

    base_model_revision: Optional[str] = field(
        default=None,
        metadata={"help": ("The base model checkpoint for weights initialization with PEFT adapters.")},
    )
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Defaults to None, hence Optional.
    model_code_revision: Optional[str] = field(default=None, metadata={"help": "The branch of the IFT model"})
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
    use_flash_attention_2: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`"
            )
        },
    )
    use_peft: bool = field(
        default=False,
        metadata={"help": ("Whether to use PEFT or not for training.")},
    )
    lora_r: Optional[int] = field(
        default=16,
        metadata={"help": ("LoRA R value.")},
    )
    lora_alpha: Optional[int] = field(
        default=32,
        metadata={"help": ("LoRA alpha.")},
    )
    lora_dropout: Optional[float] = field(
        default=0.05,
        metadata={"help": ("LoRA dropout.")},
    )
    lora_target_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": ("LoRA target modules.")},
    )
    lora_modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": ("Model layers to unfreeze & train")},
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"})
    load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"})
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
    )
    use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})

    def __post_init__(self):
        """Reject configurations that request 8 bit and 4 bit loading at the same time."""
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
@dataclass
class DataArguments:
    """
    Data-side settings for training and evaluation: which datasets and splits to use and how the inputs are
    prepared. Each field describes itself in `metadata["help"]`.
    """

    chat_template: Optional[str] = field(
        default=None,
        metadata=dict(help="The chat template to use."),
    )
    dataset_mixer: Optional[Dict[str, float]] = field(
        default=None,
        metadata=dict(help="Datasets and their proportions to be used for training ift/rl."),
    )
    # A factory is used so that every instance gets its own list object.
    dataset_splits: Optional[List[str]] = field(
        default_factory=lambda: ["train", "test"],
        metadata=dict(help="List of train test splits to use in the dataset"),
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata=dict(help="The number of processes to use for the preprocessing."),
    )
    truncation_side: Optional[str] = field(
        default=None,
        metadata=dict(help="Truncation side to use for the tokenizer."),
    )
@dataclass
class SFTConfig(transformers.TrainingArguments):
    """
    Training-process arguments: `transformers.TrainingArguments` extended with the fields declared below.
    The inherited parameters are listed at
    https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
    """

    max_seq_length: Optional[int] = field(
        default=None,
        metadata=dict(help="Used by TRL for reward model training, which tries to read this parameter in init."),
    )
    # Declared here with a default of True.
    logging_first_step: bool = field(
        default=True,
        metadata=dict(help="Whether to log and evaluate the first global_step or not."),
    )
    # Declared here with a default of "adamw_torch".
    optim: Optional[str] = field(default="adamw_torch")
@dataclass
class DPOConfig(transformers.TrainingArguments):
    """
    DPO training-process arguments: `transformers.TrainingArguments` extended with the fields declared below.
    The inherited parameters are listed at
    https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
    """

    beta: Optional[float] = field(
        default=0.1,
        metadata=dict(
            help="The beta factor in DPO loss. Higher beta means less divergence from the initial policy."
        ),
    )
    hub_model_revision: Optional[str] = field(
        default="main",
        metadata=dict(help="The Hub model branch to push the model to."),
    )
    logging_first_step: bool = field(
        default=True,
        metadata=dict(help="Whether to log and evaluate the first global_step or not."),
    )
    max_prompt_length: Optional[int] = field(
        default=None,
        metadata=dict(help="For DPO, the maximum length of the prompt to use for conditioning the model."),
    )
    max_length: Optional[int] = field(
        default=None,
        metadata=dict(help="Used by TRL for reward model training, which tries to read this parameter in init."),
    )
    # The three fields below set defaults only; the first two carry no help text.
    optim: Optional[str] = field(default="rmsprop")
    remove_unused_columns: bool = field(default=False)
    loss_type: Optional[str] = field(default="sigmoid", metadata=dict(help="The loss type for DPO."))
|