"""Fine-tune a causal language model with DPO (Direct Preference Optimization).

The script parses ``DPOScriptArguments``, ``TrainingArguments`` and
``ModelConfig`` from the command line, loads the policy model (plus a separate
reference model when no PEFT config is in use), loads the train/eval datasets,
trains with ``DPOTrainer`` while a ``PerplexityCallback`` is registered,
resumes from the last checkpoint in ``output_dir`` when one exists, saves the
model and finally pushes model and tokenizer to the Hub from the main process.
"""

import os
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import PartialState
from callbacks import PerplexityCallback
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from trl import DPOTrainer, ModelConfig
from trl.trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config


# Number of rows kept from each split when --sanity_check is set.
SANITY_CHECK_SAMPLES = 50


@dataclass
class DPOScriptArguments:
    """Command-line options specific to this DPO training script."""

    dataset_name: Optional[str] = field(default=None, metadata={"help": "the dataset name"})
    dataset_train_split: str = field(default="train", metadata={"help": "the name of the training set of the dataset"})
    dataset_eval_split: str = field(
        default="test", metadata={"help": "the name of the evaluation set of the dataset"}
    )
    # Falls back to `dataset_name` when left unset (see the Dataset section below).
    eval_dataset_name: Optional[str] = field(default=None, metadata={"help": "the dataset name"})
    beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    max_length: int = field(default=512, metadata={"help": "max length of each sample"})
    max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"})
    max_target_length: int = field(
        default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
    )
    sanity_check: bool = field(
        default=False,
        metadata={"help": f"only train and evaluate on the first {SANITY_CHECK_SAMPLES} samples"},
    )
    ignore_bias_buffers: bool = field(
        default=False,
        metadata={
            "help": "debug argument for distributed training; "
            "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See "
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"})
    gradient_checkpointing_use_reentrant: bool = field(
        default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
    )


if __name__ == "__main__":
    parser = HfArgumentParser((DPOScriptArguments, TrainingArguments, ModelConfig))
    args, training_args, model_config = parser.parse_args_into_dataclasses()

    # Fail before any model is downloaded rather than deep inside load_dataset.
    if args.dataset_name is None:
        raise ValueError("--dataset_name is required")

    if training_args.gradient_checkpointing:
        # Honour the CLI flag instead of hard-coding False (the flag's default is False).
        training_args.gradient_checkpointing_kwargs = dict(
            use_reentrant=args.gradient_checkpointing_use_reentrant
        )

    ################
    # Model & Tokenizer
    ################
    # "auto" and None are passed through unchanged; any other value is treated
    # as the name of a torch dtype attribute (looked up with getattr on torch).
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        # The KV cache is switched off whenever gradient checkpointing is on.
        use_cache=not training_args.gradient_checkpointing,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)

    # A second copy of the model is loaded as the reference only when no PEFT
    # config is used; otherwise the trainer receives model_ref=None.
    peft_config = get_peft_config(model_config)
    if peft_config is None:
        model_ref = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    else:
        model_ref = None

    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    if args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    ################
    # Dataset
    ################
    train_dataset = load_dataset(args.dataset_name, split=args.dataset_train_split)
    eval_dataset_name = args.eval_dataset_name if args.eval_dataset_name is not None else args.dataset_name
    eval_dataset = load_dataset(eval_dataset_name, split=args.dataset_eval_split)

    if args.sanity_check:
        # Clamp to the split size so that small datasets do not raise on select().
        train_dataset = train_dataset.select(range(min(len(train_dataset), SANITY_CHECK_SAMPLES)))
        eval_dataset = eval_dataset.select(range(min(len(eval_dataset), SANITY_CHECK_SAMPLES)))

    ################
    # Training
    ################
    trainer = DPOTrainer(
        model,
        model_ref,
        args=training_args,
        tokenizer=tokenizer,
        beta=args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        max_length=args.max_length,
        max_target_length=args.max_target_length,
        max_prompt_length=args.max_prompt_length,
        generate_during_eval=args.generate_during_eval,
        peft_config=peft_config,
    )

    # The callback is built after the trainer because it needs trainer.accelerator.
    callback = PerplexityCallback(
        args=training_args,
        dataset=eval_dataset,
        tokenizer=tokenizer,
        accelerator=trainer.accelerator,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
        prompt_field="prompt",
        target_field="chosen",
        hub_model_id=training_args.hub_model_id,
    )
    trainer.add_callback(callback)

    # Only look for a checkpoint when the output directory already exists, so a
    # fresh run does not depend on the directory having been created earlier.
    last_checkpoint = (
        get_last_checkpoint(training_args.output_dir) if os.path.isdir(training_args.output_dir) else None
    )
    trainer.train(resume_from_checkpoint=last_checkpoint)

    trainer.save_model(training_args.output_dir)

    if PartialState().is_main_process:
        # model = trainer.model.merge_and_unload()
        # NOTE(review): the first positional parameter of Trainer.push_to_hub appears
        # to be the commit message, so hub_model_id probably ends up as the commit
        # message here rather than selecting the repository - confirm before changing.
        trainer.push_to_hub(training_args.hub_model_id)
        tokenizer.push_to_hub(training_args.hub_model_id)