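# Example invocation (the script filename is illustrative; the flags are generated
# automatically from the ScriptArguments dataclass by HfArgumentParser):
#   python dpo.py --model_name_or_path=gpt2 --beta=0.1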
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments

from trl import DPOTrainer


@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for the DPO loss"})

    model_name_or_path: Optional[str] = field(default="gpt2", metadata={"help": "the model name"})
    learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
    max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_target_length: Optional[int] = field(
        default=128, metadata={"help": "Only used for encoder-decoder models. Max length of each sample's target"}
    )
    label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non-response tokens"})
    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})

    sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default=None,
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, '
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`, `"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )

    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See "
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    gradient_checkpointing: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use gradient checkpointing or not"}
    )
    gradient_checkpointing_kwargs: Optional[dict] = field(
        default=None,
        metadata={
            "help": "keyword arguments to be passed to the `torch.utils.checkpoint.checkpoint` method - e.g. `use_reentrant=False`"
        },
    )


def extract_anthropic_prompt(prompt_and_response):
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = "\n\nAssistant:"
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[: search_term_idx + len(search_term)]
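# Example (hypothetical transcript): for "\n\nHuman: Hi!\n\nAssistant: Hello there.",
# extract_anthropic_prompt returns "\n\nHuman: Hi!\n\nAssistant:", so everything
# after the extracted prompt is the assistant's response.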
|
|
def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
        \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["chosen"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt) :],
            "rejected": sample["rejected"][len(prompt) :],
        }

    return dataset.map(split_prompt_and_responses)
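# Example mapped sample (the text is hypothetical): if chosen == prompt + " Sure, here's how."
# and rejected == prompt + " I can't help with that.", the map above yields
#   {"prompt": "\n\nHuman: ...\n\nAssistant:",
#    "chosen": " Sure, here's how.",
#    "rejected": " I can't help with that."}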
|
|
if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)

    if script_args.ignore_bias_buffers:
        # torch distributed hack: exclude boolean buffers (e.g. bias/mask buffers) from
        # DDP broadcast, which otherwise fails with invalid-scalar-type / inplace-operation
        # errors (see the issue linked in ScriptArguments)
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
    if tokenizer.pad_token is None:
        # models such as GPT-2 ship without a pad token; reuse EOS so batched padding works
        tokenizer.pad_token = tokenizer.eos_token

    train_dataset = get_hh("train", sanity_check=script_args.sanity_check)
    eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)

    training_args = TrainingArguments(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        max_steps=script_args.max_steps,
        remove_unused_columns=False,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        learning_rate=script_args.learning_rate,
        evaluation_strategy="steps",
        logging_first_step=True,
        logging_steps=10,
        eval_steps=500,
        output_dir="./test",
        optim="rmsprop",
        warmup_steps=150,
        report_to=script_args.report_to,
        bf16=True,
        gradient_checkpointing=script_args.gradient_checkpointing,
        # gradient_checkpointing_kwargs (defined in ScriptArguments above) can be
        # forwarded here on transformers releases that support it:
        # gradient_checkpointing_kwargs=script_args.gradient_checkpointing_kwargs,
    )

    dpo_trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=script_args.max_length,
        max_target_length=script_args.max_target_length,
        max_prompt_length=script_args.max_prompt_length,
        # logs sample generations during evaluation; requires a logging
        # integration such as `wandb` to be installed
        generate_during_eval=True,
    )
|
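    # For reference, DPO (Rafailov et al., 2023) minimizes
    #   -log sigmoid(beta * (log pi(y_w|x)/pi_ref(y_w|x) - log pi(y_l|x)/pi_ref(y_l|x)))
    # over (chosen y_w, rejected y_l) pairs; a larger beta keeps the trained policy
    # closer to the frozen reference model.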
    dpo_trainer.train()
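    # Optionally persist the final policy; the output path is illustrative and not
    # part of the original script. `save_model` is inherited from transformers.Trainer.
    # dpo_trainer.save_model("./dpo_gpt2")
    # tokenizer.save_pretrained("./dpo_gpt2")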