# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python examples/scripts/reward_modeling.py \
    --model_name_or_path=facebook/opt-350m \
    --output_dir="reward_modeling_anthropic_hh" \
    --per_device_train_batch_size=16 \
    --num_train_epochs=1 \
    --gradient_accumulation_steps=2 \
    --gradient_checkpointing=True \
    --learning_rate=1.41e-5 \
    --report_to="wandb" \
    --remove_unused_columns=False \
    --optim="adamw_torch" \
    --logging_steps=10 \
    --eval_strategy="steps" \
    --eval_steps=500 \
    --max_length=512
"""

import warnings
from dataclasses import dataclass, field

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser

from trl import (
    ModelConfig,
    RewardConfig,
    RewardTrainer,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


@dataclass
class DatasetConfig:
    reedsy_dataset: bool = field(default=True, metadata={"help": "Whether to load the Reedsy story-pair dataset"})
    datapath: str = field(default=None, metadata={"help": "Path to the dataset"})
    pairpath: str = field(default=None, metadata={"help": "Path to the story pairs"})
    split_by: str = field(default="random", metadata={"help": "How to split the dataset"})
    dt_mode: str = field(default="m3", metadata={"help": "DT mode"})
    dt_margin: bool = field(default=False, metadata={"help": "DT margin flag"})
    time_window: int = field(default=3600, metadata={"help": "Time window for DT"})
    used_dataset_size: int = field(default=-1, metadata={"help": "Size of the dataset to use (-1 for all)"})
    # Only used when `reedsy_dataset` is False and a Hugging Face Hub dataset is loaded instead.
    dataset_name: str = field(default=None, metadata={"help": "Name of a Hugging Face preference dataset"})
    dataset_train_split: str = field(default="train", metadata={"help": "Name of the training split"})
    dataset_test_split: str = field(default="test", metadata={"help": "Name of the evaluation split"})


tqdm.pandas()


if __name__ == "__main__":
    parser = HfArgumentParser((RewardConfig, ModelConfig, DatasetConfig))
    config, model_config, dataset_config = parser.parse_args_into_dataclasses()
    config.gradient_checkpointing_kwargs = dict(use_reentrant=False)

    ################
    # Model & Tokenizer
    ################
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_config.model_name_or_path, num_labels=1, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )

    if model_config.lora_task_type != "SEQ_CLS":
        warnings.warn(
            "You are using a `task_type` that is different from `SEQ_CLS` for PEFT. This will lead to silent bugs."
            " Make sure to pass --lora_task_type SEQ_CLS when using this script."
        )

    ################
    # Dataset
    ################
    if not dataset_config.reedsy_dataset:
        # NOTE: this branch expects the Hub dataset to already provide the tokenized
        # `input_ids_chosen` / `attention_mask_chosen` / `input_ids_rejected` /
        # `attention_mask_rejected` columns required by RewardTrainer.
        raw_datasets = load_dataset(dataset_config.dataset_name)
        train_dataset = raw_datasets[dataset_config.dataset_train_split]
        eval_dataset = raw_datasets[dataset_config.dataset_test_split]
    else:
        from dataloader import StoryPairDataset

        SPdataloader = StoryPairDataset(
            dataset_config.datapath,
            dataset_config.pairpath,
            tokenizer,
            task="rm",
            used_dataset_size=dataset_config.used_dataset_size,
            train_test_split=0.1,
            split_by=dataset_config.split_by,
            max_len=4096,
            mode=dataset_config.dt_mode,
            max_time_window=dataset_config.time_window,
            least_likes=10,
            margin=dataset_config.dt_margin,
        )

        def preprocess_function(examples):
            # Tokenize the chosen and rejected stories into the paired columns
            # expected by RewardTrainer.
            tokenized_chosen = tokenizer(examples["chosen_text"], truncation=True)
            tokenized_rejected = tokenizer(examples["rejected_text"], truncation=True)
            examples["input_ids_chosen"] = tokenized_chosen["input_ids"]
            examples["attention_mask_chosen"] = tokenized_chosen["attention_mask"]
            examples["input_ids_rejected"] = tokenized_rejected["input_ids"]
            examples["attention_mask_rejected"] = tokenized_rejected["attention_mask"]
            return examples

        train_dataset = SPdataloader.dataset["train"].map(preprocess_function, num_proc=32)
        eval_dataset = SPdataloader.dataset["test"].map(preprocess_function, num_proc=32)

    print("dataset ready")

    ################
    # Training
    ################
    trainer = RewardTrainer(
        model=model,
        tokenizer=tokenizer,
        args=config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_config),
    )
    trainer.train()

    saving_path = (
        "/workspace/RMmodels/"
        + model_config.model_name_or_path.split("/")[-1]
        + str(dataset_config.time_window)
    )
    trainer.save_model(saving_path)
    trainer.push_to_hub()

    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    print(metrics)
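
    ################
    # Post-training sanity check (illustrative sketch, not part of the original pipeline)
    ################
    # A minimal, hedged example: score one held-out pair with the trained reward model and
    # check that the chosen story receives a higher scalar reward than the rejected one.
    # It assumes the Reedsy branch was used, so `eval_dataset` still carries the raw
    # `chosen_text` / `rejected_text` columns alongside the tokenized ones.
    reward_model = trainer.model
    reward_model.eval()
    example = eval_dataset[0]
    device = next(reward_model.parameters()).device
    with torch.no_grad():
        rewards = {}
        for column in ("chosen_text", "rejected_text"):
            inputs = tokenizer(
                example[column], truncation=True, max_length=config.max_length, return_tensors="pt"
            ).to(device)
            # num_labels=1, so the single logit is the scalar reward for this sequence.
            rewards[column] = reward_model(**inputs).logits[0, 0].item()
    print(f"chosen reward: {rewards['chosen_text']:.4f} | rejected reward: {rewards['rejected_text']:.4f}")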