# Hub page header: workspace/reward_modeling.py - Penghaoo - commit 4d3e798 (verified), message "End of training"
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python examples/scripts/reward_modeling.py \
--model_name_or_path=facebook/opt-350m \
--output_dir="reward_modeling_anthropic_hh" \
--per_device_train_batch_size=16 \
--num_train_epochs=1 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing=True \
--learning_rate=1.41e-5 \
--report_to="wandb" \
--remove_unused_columns=False \
--optim="adamw_torch" \
--logging_steps=10 \
--eval_strategy="steps" \
--eval_steps=500 \
--max_length=512
"""
import warnings
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
from trl import ModelConfig, RewardConfig, RewardTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
from dataclasses import dataclass, field
from transformers import TrainingArguments
# Debug marker: printed once all of the imports above have completed.
print('imported')
@dataclass
class DatasetConfig:
    """Command-line options that select and shape the preference dataset.

    The script parses this dataclass with ``HfArgumentParser`` next to
    ``RewardConfig`` and ``ModelConfig``.  When ``reedsy_dataset`` is true the
    story-pair loader is used and ``datapath``/``pairpath``/``split_by``/
    ``dt_*``/``time_window``/``used_dataset_size`` are forwarded to it;
    otherwise ``dataset_name`` and the two split names drive ``load_dataset``.
    """

    # A real boolean: with the previous ``str`` annotation the parser turned
    # ``--reedsy_dataset False`` into the truthy string "False", so the
    # ``load_dataset`` branch could never be selected.
    reedsy_dataset: bool = field(
        default=True,
        metadata={"help": "Use the Reedsy story-pair dataset instead of a dataset loaded with `load_dataset`"},
    )
    datapath: str = field(default=None, metadata={"help": "Path to the dataset"})
    pairpath: str = field(default=None, metadata={"help": "Path to the story pairs"})
    split_by: str = field(default="random", metadata={"help": "How to split the dataset"})
    dt_mode: str = field(default="m3", metadata={"help": "DT mode"})
    dt_margin: bool = field(default=False, metadata={"help": "DT margin flag"})
    time_window: int = field(default=3600, metadata={"help": "Time window for DT"})
    used_dataset_size: int = field(default=-1, metadata={"help": "Size of the dataset to use"})
    # Fields read by the ``load_dataset`` branch of the script; they were
    # referenced there but never declared, which raised AttributeError.
    dataset_name: str = field(
        default=None,
        metadata={"help": "Name or path passed to `load_dataset` when `reedsy_dataset` is false"},
    )
    dataset_train_split: str = field(default="train", metadata={"help": "Split used for training"})
    dataset_test_split: str = field(default="test", metadata={"help": "Split used for evaluation"})
tqdm.pandas()

if __name__ == "__main__":
    # Entry point: parse CLI options, build the reward model and tokenizer,
    # prepare the chosen/rejected pairs, train, save, evaluate and upload.
    parser = HfArgumentParser((RewardConfig, ModelConfig, DatasetConfig))
    config, model_config, dataset_config = parser.parse_args_into_dataclasses()
    # Request the non-reentrant implementation whenever gradient checkpointing is enabled.
    config.gradient_checkpointing_kwargs = dict(use_reentrant=False)

    ################
    # Model & Tokenizer
    ################
    # "auto"/None are passed through unchanged; any other value names a torch dtype attribute.
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        # Previously computed but never forwarded, so --torch_dtype had no effect.
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
    )
    # num_labels=1: the classification head emits one scalar per sequence.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_config.model_name_or_path, num_labels=1, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )

    # The task type only matters when a PEFT adapter is requested; without the
    # `use_peft` guard this warning fired on every non-PEFT run.
    if model_config.use_peft and model_config.lora_task_type != "SEQ_CLS":
        warnings.warn(
            "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
            " Make sure to pass --lora_task_type SEQ_CLS when using this script."
        )

    ################
    # Dataset
    ################
    if not dataset_config.reedsy_dataset:
        # NOTE(review): this branch passes the raw splits to the trainer without
        # tokenizing them here - confirm the installed RewardTrainer accepts that.
        raw_datasets = load_dataset(dataset_config.dataset_name)
        train_dataset = raw_datasets[dataset_config.dataset_train_split]
        eval_dataset = raw_datasets[dataset_config.dataset_test_split]
    else:
        # Project-local loader, imported lazily so the hub-dataset path does not need it.
        from dataloader import StoryPairDataset

        SPdataloader = StoryPairDataset(
            dataset_config.datapath,
            dataset_config.pairpath,
            tokenizer,
            task='rm',
            used_dataset_size=dataset_config.used_dataset_size,
            train_test_split=0.1,
            split_by=dataset_config.split_by,
            max_len=4096,
            mode=dataset_config.dt_mode,
            max_time_window=dataset_config.time_window,
            least_likes=10,
            margin=dataset_config.dt_margin,
        )
        print('dataset ready')

        def preprocess_function(examples):
            """Add token ids and attention masks for the chosen and rejected texts.

            Called through ``Dataset.map`` without ``batched=True``.
            NOTE(review): ``truncation=True`` is given no ``max_length`` here, so
            ``config.max_length`` is not applied at this step - confirm intended.
            """
            tokenized_input_chosen = tokenizer(examples['chosen_text'], truncation=True)
            tokenized_input_rejected = tokenizer(examples['rejected_text'], truncation=True)
            examples['input_ids_chosen'] = tokenized_input_chosen['input_ids']
            examples['attention_mask_chosen'] = tokenized_input_chosen['attention_mask']
            examples['input_ids_rejected'] = tokenized_input_rejected['input_ids']
            examples['attention_mask_rejected'] = tokenized_input_rejected['attention_mask']
            return examples

        train_dataset = SPdataloader.dataset['train'].map(preprocess_function, num_proc=32)
        eval_dataset = SPdataloader.dataset['test'].map(preprocess_function, num_proc=32)

    print('dataset ready')

    ################
    # Training
    ################
    trainer = RewardTrainer(
        model=model,
        tokenizer=tokenizer,
        args=config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_config),
    )
    trainer.train()

    # Checkpoint directory: <last component of the model name><time window>.
    model_basename = model_config.model_name_or_path.split('/')[-1]
    saving_path = '/workspace/RMmodels/' + model_basename + str(dataset_config.time_window)
    trainer.save_model(saving_path)

    # Evaluate before uploading so that a failed push cannot discard the final metrics.
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    print(metrics)

    trainer.push_to_hub()