import os |
import json |
import torch |
from typing import Any, Dict, List, Literal, Optional |
from dataclasses import asdict, dataclass, field |
@dataclass
class DatasetAttr:
    """Describe where one dataset is loaded from and which columns hold its fields.

    ``load_from`` is the source kind (``DataTrainingArguments`` creates "hf_hub",
    "script" or "file"). ``dataset_name`` is set for hub/script sources,
    ``file_name``/``file_sha1`` for local files.
    """
    load_from: str
    dataset_name: Optional[str] = None
    file_name: Optional[str] = None
    file_sha1: Optional[str] = None

    def __repr__(self) -> str:
        """Return the dataset's identifying name.

        Prefers ``dataset_name``, then ``file_name``. Falls back to ``load_from``
        so that ``repr()`` never fails when both names are ``None``; returning
        ``None`` from ``__repr__`` raises ``TypeError``.
        """
        if self.dataset_name is not None:
            return self.dataset_name
        if self.file_name is not None:
            return self.file_name
        return self.load_from

    def __post_init__(self):
        # Default column names. These are plain attributes, not dataclass fields,
        # so they are excluded from __init__, __eq__ and asdict().
        # DataTrainingArguments may overwrite them from a "columns" entry in dataset_info.json.
        self.prompt_column: Optional[str] = "instruction"
        self.query_column: Optional[str] = "input"
        self.response_column: Optional[str] = "output"
        self.history_column: Optional[str] = None
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
    """
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
    )
    model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
    )
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model."}
    )
    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
        default="nf4",
        metadata={"help": "Quantization data type to use in int4 training."}
    )
    double_quantization: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use double quantization in int4 training or not."}
    )
    compute_dtype: Optional[torch.dtype] = field(
        default=None,
        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
    )
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
    )
    reward_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
    )
    resume_lora_training: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
    )
    plot_loss: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
    )

    def __post_init__(self):
        """Normalize and validate the arguments after construction.

        ``checkpoint_dir`` may be given as a comma-separated string. It is replaced
        by a list of stripped, non-empty directory names, or ``None`` if no
        non-empty name remains.

        Raises:
            ValueError: if ``quantization_bit`` is set to anything other than 4 or 8.
        """
        # NOTE(review): after this runs, checkpoint_dir holds a list even though it is
        # annotated Optional[str]; the annotation is presumably kept as str for CLI
        # parsing — confirm. The isinstance guard makes a repeated call a no-op.
        if isinstance(self.checkpoint_dir, str):
            # Skip empty entries so a trailing comma does not yield "" as a checkpoint path.
            dirs = [cd.strip() for cd in self.checkpoint_dir.split(",") if cd.strip()]
            self.checkpoint_dir = dirs or None
        # Raise instead of assert: assert statements are stripped under `python -O`.
        if self.quantization_bit is not None and self.quantization_bit not in (4, 8):
            raise ValueError("We only accept 4-bit or 8-bit quantization.")
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and evaluation.
    """
    dataset: Optional[str] = field(
        default="alpaca_zh",
        metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
    )
    dataset_dir: Optional[str] = field(
        default="data",
        metadata={"help": "The name of the folder containing datasets."}
    )
    split: Optional[str] = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."}
    )
    overwrite_cache: Optional[bool] = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."}
    )
    max_source_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total input sequence length after tokenization."}
    )
    max_target_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total output sequence length after tokenization."}
    )
    max_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
    )
    eval_num_beams: Optional[int] = field(
        default=None,
        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
    )
    ignore_pad_token_for_loss: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
    )
    source_prefix: Optional[str] = field(
        default=None,
        metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    dev_ratio: Optional[float] = field(
        default=0,
        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
    )
    prompt_template: Optional[str] = field(
        default="alpaca",
        metadata={"help": "Which template to use for constructing prompts in training and inference."}
    )

    def __post_init__(self):
        """Resolve the comma-separated ``dataset`` names into ``self.dataset_list``.

        Reads ``<dataset_dir>/dataset_info.json`` and builds one ``DatasetAttr``
        per requested name, in the order given.

        Raises:
            ValueError: if no dataset name is given, or a name is not defined in
                dataset_info.json.
        """
        # Ignore empty entries so that e.g. "alpaca_zh," does not look up a "" dataset.
        dataset_names = [ds.strip() for ds in (self.dataset or "").split(",") if ds.strip()]
        if not dataset_names:
            raise ValueError("No dataset specified, please set `dataset`.")

        info_path = os.path.join(self.dataset_dir, "dataset_info.json")
        # Explicit encoding: do not depend on the platform's default locale encoding.
        with open(info_path, "r", encoding="utf-8") as f:
            dataset_info = json.load(f)

        self.dataset_list: List[DatasetAttr] = []
        for name in dataset_names:
            if name not in dataset_info:
                raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
            self.dataset_list.append(self._build_dataset_attr(dataset_info[name]))

    @staticmethod
    def _build_dataset_attr(info: Dict[str, Any]) -> "DatasetAttr":
        """Create the DatasetAttr described by one dataset_info.json entry.

        The source is chosen by precedence "hf_hub_url" > "script_url" > local
        "file_name". A local entry without "file_name" raises ``KeyError``.
        """
        if "hf_hub_url" in info:
            dataset_attr = DatasetAttr("hf_hub", dataset_name=info["hf_hub_url"])
        elif "script_url" in info:
            dataset_attr = DatasetAttr("script", dataset_name=info["script_url"])
        else:
            dataset_attr = DatasetAttr(
                "file",
                file_name=info["file_name"],
                file_sha1=info.get("file_sha1", None)
            )

        if "columns" in info:
            # A "columns" mapping replaces all four defaults; keys it omits become None.
            columns = info["columns"]
            dataset_attr.prompt_column = columns.get("prompt", None)
            dataset_attr.query_column = columns.get("query", None)
            dataset_attr.response_column = columns.get("response", None)
            dataset_attr.history_column = columns.get("history", None)

        return dataset_attr
@dataclass
class FinetuningArguments:
    """
    Arguments pertaining to which techniques we are going to fine-tune with.
    """
    finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
        default="lora",
        metadata={"help": "Which fine-tuning method to use."}
    )
    num_layer_trainable: Optional[int] = field(
        default=3,
        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
    )
    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
        default="mlp",
        metadata={"help": "Name of trainable modules for Freeze fine-tuning. "
                          "LLaMA choices: [\"mlp\", \"self_attn\"], "
                          "BLOOM choices: [\"mlp\", \"self_attention\"], "
                          "Baichuan choices: [\"mlp\", \"self_attn\"]"}
    )
    lora_rank: Optional[int] = field(
        default=8,
        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
    )
    lora_alpha: Optional[float] = field(
        default=32.0,
        metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
    )
    lora_dropout: Optional[float] = field(
        default=0.1,
        metadata={"help": "Dropout rate for the LoRA fine-tuning."}
    )
    lora_target: Optional[str] = field(
        default="q_proj,v_proj",
        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. "
                          "LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], "
                          "BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], "
                          "Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
    )
    # Appended last so the positional order of the existing fields is unchanged.
    # The default of 28 reproduces the previously hard-coded last layer id 27.
    num_hidden_layers: Optional[int] = field(
        default=28,
        metadata={"help": "Number of decoder layers in the model, used to locate the last layers for Freeze fine-tuning."}
    )

    def __post_init__(self):
        """Validate the method and derive ``lora_target`` / ``trainable_layers``.

        Sets ``self.trainable_layers`` to names of the form
        ``"layers.<id>.<name_module_trainable>"``. A positive
        ``num_layer_trainable`` selects the last n of ``num_hidden_layers`` layers,
        from the top down. A non-positive value -n selects the first n layers.

        Raises:
            ValueError: if ``finetuning_type`` is not a supported method.
        """
        # Raise instead of assert: assert statements are stripped under `python -O`.
        if self.finetuning_type not in ["none", "freeze", "lora", "full"]:
            raise ValueError("Invalid fine-tuning method.")

        # Accept "a,b" from the CLI; a list (e.g. from load_from_json) is kept as is.
        if isinstance(self.lora_target, str):
            self.lora_target = [target.strip() for target in self.lora_target.split(",") if target.strip()]

        if self.num_layer_trainable > 0:  # the last n layers, counted down from the top
            trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
        else:  # the first -n layers
            trainable_layer_ids = list(range(-self.num_layer_trainable))
        self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]

    def save_to_json(self, json_path: str):
        """Saves the content of this instance in JSON format inside `json_path`.

        Only dataclass fields are written; derived attributes such as
        ``trainable_layers`` are recomputed on load.
        """
        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
        with open(json_path, "w", encoding="utf-8") as f:
            f.write(json_string)

    @classmethod
    def load_from_json(cls, json_path: str):
        """Creates an instance from the content of `json_path`."""
        with open(json_path, "r", encoding="utf-8") as f:
            return cls(**json.load(f))
@dataclass
class GeneratingArguments:
    """
    Decoding parameters for text generation.

    Every field carries its command-line help text in ``metadata["help"]``;
    ``to_dict`` returns the current values keyed by field name.
    """
    do_sample: Optional[bool] = field(default=True, metadata=dict(
        help="Whether or not to use sampling, use greedy decoding otherwise."
    ))
    temperature: Optional[float] = field(default=0.95, metadata=dict(
        help="The value used to modulate the next token probabilities."
    ))
    top_p: Optional[float] = field(default=0.7, metadata=dict(
        help="The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
    ))
    top_k: Optional[int] = field(default=50, metadata=dict(
        help="The number of highest probability vocabulary tokens to keep for top-k filtering."
    ))
    num_beams: Optional[int] = field(default=1, metadata=dict(
        help="Number of beams for beam search. 1 means no beam search."
    ))
    max_new_tokens: Optional[int] = field(default=512, metadata=dict(
        help="The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."
    ))
    repetition_penalty: Optional[float] = field(default=1.0, metadata=dict(
        help="The parameter for repetition penalty. 1.0 means no penalty."
    ))

    def to_dict(self) -> Dict[str, Any]:
        """Return the decoding parameters as a ``{field_name: value}`` dict, in declaration order."""
        params = asdict(self)
        return params