from dataclasses import dataclass, field |
from typing import Any, Dict, List, Optional, Union |
import bitsandbytes as bnb |
import torch |
from accelerate import Accelerator, DistributedDataParallelKwargs |
from datasets import load_dataset |
from peft import LoraConfig, prepare_model_for_kbit_training |
from torch.utils.data import DataLoader |
from tqdm import tqdm |
from transformers import ( |
AutoModelForSequenceClassification, |
AutoTokenizer, |
BitsAndBytesConfig, |
HfArgumentParser, |
PreTrainedTokenizerBase, |
pipeline, |
) |
import wandb |
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed |
from trl.core import LengthSampler |
from trl.models.modeling_value_adapter import AutoModelForCausalLMWithValueAdapter |
tqdm.pandas() |
@dataclass
class ScriptArguments:
    """
    Command-line options for fine-tuning a causal LM with PPO.

    Every field has a default, so the class can be instantiated without
    arguments. The ``help`` entry of each field's metadata is the text shown
    by the argument parser.
    """

    # --- model and data -----------------------------------------------------
    model_name: Optional[str] = field(default="", metadata={"help": "the model name"})
    reward_adapter_name: Optional[str] = field(default="", metadata={"help": "the reward model name"})
    dataset_name: Optional[str] = field(
        default="CarperAI/openai_summarize_tldr", metadata={"help": "the dataset name"}
    )
    train_split: Optional[str] = field(default="train", metadata={"help": "the dataset split to train on"})
    eval_split: Optional[str] = field(default="train", metadata={"help": "the dataset split to evaluate on"})

    # --- PPO optimisation ---------------------------------------------------
    log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
    mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
    batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
    ppo_epochs: Optional[int] = field(default=4, metadata={"help": "the number of ppo epochs"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    adafactor: Optional[bool] = field(default=False, metadata={"help": "whether to use the adafactor optimizer"})
    early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
    target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
    reward_baseline: Optional[float] = field(
        default=0.0,
        metadata={"help": "a baseline value that is subtracted from the reward"},
    )
    batched_gen: Optional[bool] = field(default=False, metadata={"help": "whether to use the batched text gen"})

    # --- checkpointing and run length ---------------------------------------
    save_steps: Optional[int] = field(default=1000, metadata={"help": "the number of steps to save at"})
    save_strategy: Optional[str] = field(default="steps")
    output_dir: Optional[str] = field(default="runs/", metadata={"help": "the directory checkpoints are saved to"})
    seed: Optional[int] = field(default=0, metadata={"help": "the seed"})
    steps: Optional[int] = field(default=20000, metadata={"help": "the total number of training steps"})

    # --- KL control ---------------------------------------------------------
    init_kl_coef: Optional[float] = field(
        default=0.2,
        metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
    )
    adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})

    # --- value / reward model variants --------------------------------------
    value_adapter: Optional[bool] = field(default=False)
    separate_reward_model: Optional[str] = field(default=None, metadata={"help": "the reward model name"})

    # --- sequence lengths ---------------------------------------------------
    output_min_length: Optional[int] = field(
        default=24, metadata={"help": "the minimum length of a sampled response"}
    )
    output_max_length: Optional[int] = field(
        default=48, metadata={"help": "the maximum length of a sampled response"}
    )
    input_max_length: Optional[int] = field(
        default=512, metadata={"help": "the maximum prompt length; longer prompts are truncated"}
    )

    # --- precision and quantization of the policy ---------------------------
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    bf16: Optional[bool] = field(
        default=False,
        metadata={
            "help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
        },
    )
    fp16: Optional[bool] = field(
        default=False,
        metadata={
            "help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
        },
    )

    # --- LoRA ---------------------------------------------------------------
    use_lora: Optional[bool] = field(
        default=True,
    )
    lora_alpha: Optional[float] = field(default=32, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
    lora_all_linear: Optional[bool] = field(default=False, metadata={"help": "lora adapter on all linear layers"})

    # --- gold (evaluation) reward model -------------------------------------
    eval_steps: Optional[int] = field(default=None)
    gold_model_name: Optional[str] = field(default=None, metadata={"help": "the gold reward model name"})
    gold_in_8bit: Optional[bool] = field(
        default=False, metadata={"help": "load the gold model in 8 bits precision"}
    )
    gold_in_4bit: Optional[bool] = field(
        default=False, metadata={"help": "load the gold model in 4 bits precision"}
    )
    gold_bf16: Optional[bool] = field(
        default=False,
    )
    gold_fp16: Optional[bool] = field(
        default=False,
    )
    gold_eval_greedy: Optional[bool] = field(default=True)

    # --- miscellaneous ------------------------------------------------------
    input_ids_input: Optional[bool] = field(
        default=False,
    )
    strip_prompt: Optional[bool] = field(
        default=False,
    )
    just_eval: Optional[bool] = field(default=False)
    # Read by create_and_prepare_model when a quantized model is prepared for
    # k-bit training; declared last so the order of existing fields is unchanged.
    gradient_checkpointing: Optional[bool] = field(
        default=False,
        metadata={"help": "enable gradient checkpointing when preparing a quantized model for training"},
    )
@dataclass
class PromptCollator:
    """
    Collate raw examples into a left-padded, tokenized batch of prompts.

    Attributes are read on every call:
        tokenizer: callable tokenizer exposing a ``padding_side`` attribute.
        padding: padding strategy forwarded to the tokenizer.
        max_prompt_length: ``max_length`` forwarded to the tokenizer.
        prompt_field: key under which each example stores its prompt text.
        return_tensors: tensor format forwarded to the tokenizer.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str] = True
    max_prompt_length: Optional[int] = None
    prompt_field: str = "prompt"
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Tokenize the prompts of ``features`` with left padding.

        Returns the tokenizer output with the raw prompt strings added under
        the ``"prompt"`` key.
        """
        prompts = [feat[self.prompt_field] for feat in features]

        # Temporarily switch to left padding; the finally clause guarantees
        # the shared tokenizer gets its original side back even when
        # tokenization raises.
        original_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            tokenized_batch = self.tokenizer(
                prompts,
                truncation=True,
                padding=self.padding,  # honour the configured strategy
                max_length=self.max_prompt_length,
                return_tensors=self.return_tensors,
            )
        finally:
            self.tokenizer.padding_side = original_side

        tokenized_batch["prompt"] = prompts
        return tokenized_batch
def find_all_linear_names(args, model, exclude=("lm_head", "score")):
    """
    Return the leaf names of all linear layers found in ``model``.

    The layer class that is searched for depends on ``args``:
    ``bnb.nn.Linear4bit`` when ``args.load_in_4bit`` is set,
    ``bnb.nn.Linear8bitLt`` when ``args.load_in_8bit`` is set, and
    ``torch.nn.Linear`` otherwise.

    Args:
        args: object exposing ``load_in_4bit`` and ``load_in_8bit``.
        model: module whose ``named_modules()`` are scanned.
        exclude: leaf names dropped from the result. The default keeps the
            previous behaviour of removing ``"lm_head"`` and ``"score"``.

    Returns:
        A sorted list of unique leaf module names, so the result is
        reproducible across runs.
    """
    if args.load_in_4bit:
        linear_cls = bnb.nn.Linear4bit
    elif args.load_in_8bit:
        linear_cls = bnb.nn.Linear8bitLt
    else:
        linear_cls = torch.nn.Linear

    # Only the last component of the dotted module path is kept.
    leaf_names = {
        name.rsplit(".", 1)[-1]
        for name, module in model.named_modules()
        if isinstance(module, linear_cls)
    }
    return sorted(leaf_names.difference(exclude))
def create_and_prepare_model(args):
    """
    Load the policy model and its tokenizer as configured by ``args``.

    Only attributes of ``args`` are read; the function does not depend on
    module-level state.

    Args:
        args: object exposing the model-related options of
            ``ScriptArguments`` (``model_name``, ``reward_adapter_name``,
            ``load_in_8bit``, ``load_in_4bit``, ``bf16``, ``value_adapter``,
            ``use_lora`` and the ``lora_*`` options). An optional
            ``gradient_checkpointing`` attribute is honoured on the quantized
            path and treated as ``False`` when absent.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is returned in eval mode.

    Raises:
        ValueError: if both 8-bit and 4-bit loading are requested.
    """
    if args.load_in_8bit and args.load_in_4bit:
        raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
    elif args.load_in_8bit or args.load_in_4bit:
        quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)
        # Place the whole quantized model on this process's device.
        device_map = {"": Accelerator().local_process_index}
    else:
        device_map = None
        quantization_config = None

    # NOTE(review): only bf16 is consulted here; args.fp16 has no effect on
    # the dtype the policy is loaded in.
    if args.bf16:
        torch_dtype = torch.bfloat16
    else:
        torch_dtype = torch.float32

    if args.value_adapter:
        model_cls = AutoModelForCausalLMWithValueAdapter
    else:
        model_cls = AutoModelForCausalLMWithValueHead

    if args.use_lora:
        if args.lora_all_linear:
            target_modules = ["dense_h_to_4h", "dense_4h_to_h", "query_key_value", "dense"]
        else:
            target_modules = None

        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=target_modules,
            modules_to_save=["score"],
        )
    else:
        peft_config = None

    model = model_cls.from_pretrained(
        args.model_name,
        quantization_config=quantization_config,
        device_map=device_map,
        torch_dtype=torch_dtype,
        peft_config=peft_config,
        reward_adapter=args.reward_adapter_name,
    )

    if quantization_config is not None:
        # Fall back to False when the caller's args object has no such option.
        use_gradient_checkpointing = getattr(args, "gradient_checkpointing", False)
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=use_gradient_checkpointing)
        # The flag is cleared once it has been consumed here.
        args.gradient_checkpointing = False

    model.config.torch_dtype = torch_dtype

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Fall back to EOS as the padding token when none is configured.
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

    if getattr(model.config, "pad_token_id", None) is None:
        model.config.pad_token_id = model.config.eos_token_id

    model.eval()

    return model, tokenizer
def create_and_prepare_dataset(args, tokenizer, split, num_proc=2):
    """
    Load one split of ``args.dataset_name`` and tokenize its prompts.

    The ``prompt`` column is optionally whitespace-stripped (when
    ``args.strip_prompt`` is set), renamed to ``query`` and tokenized with
    truncation to ``args.input_max_length``. All other original columns are
    removed and the dataset is switched to the ``"torch"`` format.
    """
    raw_dataset = load_dataset(args.dataset_name, split=split)

    if args.strip_prompt:

        def _strip_whitespace(examples):
            examples["prompt"] = [text.strip() for text in examples["prompt"]]
            return examples

        raw_dataset = raw_dataset.map(_strip_whitespace, batched=True)

    raw_dataset = raw_dataset.rename_column("prompt", "query")

    # Everything except the prompt text is dropped during tokenization.
    columns_to_drop = [column for column in raw_dataset.column_names if column != "query"]
    tokenizer_kwargs = {"truncation": True, "max_length": args.input_max_length}

    tokenized_dataset = raw_dataset.map(
        tokenizer,
        batched=True,
        num_proc=num_proc,
        input_columns="query",
        remove_columns=columns_to_drop,
        fn_kwargs=tokenizer_kwargs,
    )
    tokenized_dataset.set_format("torch")
    return tokenized_dataset
def collator(data):
    """
    Turn a list of example dicts into a dict of per-key lists.

    The keys of the first example define the keys of the result; every
    example must provide them. An empty batch yields an empty dict.
    """
    if not data:
        return {}
    return {key: [example[key] for example in data] for key in data[0]}
def decode_and_encode(
    output_token_ids: List[torch.Tensor], tokenizer, max_length, de_and_retokenize=True, device=None
):
    """
    Build a padded batch encoding from a list of token-id sequences.

    Args:
        output_token_ids: one 1-D tensor of token ids per sequence.
        tokenizer: tokenizer used for decoding, re-encoding and padding.
        max_length: ``max_length`` forwarded to the tokenizer.
        de_and_retokenize: when true, the sequences are decoded to text
            (special tokens skipped) and tokenized again; otherwise the given
            ids are left-padded as they are.
        device: device the result is moved to. Defaults to the device of the
            first input sequence.

    Returns:
        The tokenizer's batch encoding, moved to ``device``.
    """
    if device is None and len(output_token_ids) > 0:
        device = getattr(output_token_ids[0], "device", None)

    if de_and_retokenize:
        # Work from the ids that were passed in rather than from any
        # module-level state.
        texts = tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)
        output_encoding = tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=False,
            max_length=max_length,
        )
    else:
        default_padding_side = tokenizer.padding_side
        tokenizer.padding_side = "left"
        try:
            full_response_mask = [torch.ones_like(element) for element in output_token_ids]
            full_response_encoding = {"input_ids": output_token_ids, "attention_mask": full_response_mask}
            output_encoding = tokenizer.pad(
                full_response_encoding,
                padding=True,
                max_length=max_length,
                return_tensors="pt",
            )
        finally:
            # Always hand the tokenizer back in its original configuration.
            tokenizer.padding_side = default_padding_side

    if device is not None:
        output_encoding = output_encoding.to(device)
    return output_encoding
def create_and_prepare_gold_model(script_args, accelerator):
    """
    Load the gold reward model used for evaluation.

    Args:
        script_args: object exposing ``gold_model_name``, ``gold_in_8bit``,
            ``gold_in_4bit``, ``gold_bf16`` and ``gold_fp16``.
        accelerator: accelerator used to pick the device for quantized
            loading and to prepare the model.

    Returns:
        The prepared sequence-classification model in eval mode.

    Raises:
        ValueError: if both 8-bit and 4-bit loading are requested.
    """
    # Same guard as for the policy model: the two quantization modes are
    # mutually exclusive.
    if script_args.gold_in_8bit and script_args.gold_in_4bit:
        raise ValueError("You can't load the gold model in 8 bits and 4 bits at the same time")

    if script_args.gold_in_8bit or script_args.gold_in_4bit:
        gold_quantization_config = BitsAndBytesConfig(
            load_in_8bit=script_args.gold_in_8bit, load_in_4bit=script_args.gold_in_4bit
        )
        gold_device_map = {"": accelerator.local_process_index}
    else:
        gold_device_map = None
        gold_quantization_config = None

    # bf16 takes precedence over fp16 when both flags are set.
    if script_args.gold_bf16:
        torch_dtype = torch.bfloat16
    elif script_args.gold_fp16:
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    gold_model = AutoModelForSequenceClassification.from_pretrained(
        script_args.gold_model_name,
        quantization_config=gold_quantization_config,
        torch_dtype=torch_dtype,
        device_map=gold_device_map,
    )

    # Fall back to EOS as the padding token when none is configured.
    if getattr(gold_model.config, "pad_token_id", None) is None:
        gold_model.config.pad_token_id = gold_model.config.eos_token_id

    gold_model = accelerator.prepare(gold_model)
    gold_model.eval()

    return gold_model
def create_and_prepare_eval(args, tokenizer, accelerator):
    """
    Build the evaluation dataloader for ``args.eval_split``.

    The split is loaded from ``args.dataset_name``, its prompts are
    whitespace-stripped when ``args.strip_prompt`` is set, and the resulting
    dataset is batched with ``args.batch_size`` and prepared by
    ``accelerator``. The ``tokenizer`` argument is accepted but not used, so
    batches contain the dataset's raw columns.
    """
    eval_dataset = load_dataset(args.dataset_name, split=args.eval_split)

    if args.strip_prompt:

        def _strip_whitespace(examples):
            examples["prompt"] = [text.strip() for text in examples["prompt"]]
            return examples

        eval_dataset = eval_dataset.map(_strip_whitespace, batched=True)

    eval_loader = DataLoader(eval_dataset, batch_size=args.batch_size)
    return accelerator.prepare(eval_loader)
def _pad_to_length(tensor, length, pad_value):
    """Right-pad ``tensor`` along its last dimension to ``length`` with ``pad_value``."""
    missing = length - tensor.size(-1)
    if missing <= 0:
        return tensor
    filler = torch.full(
        (*tensor.shape[:-1], missing),
        pad_value,
        dtype=tensor.dtype,
        device=tensor.device,
    )
    return torch.cat([tensor, filler], dim=-1)


def get_batch_samples(
    accelerator,
    model,
    tokenizer,
    input_ids,
    attention_mask,
    return_ids=False,
    generation_config=None,
    max_length=None,
):
    """
    Generate completions from the policy and from the adapter-free reference.

    The same inputs are generated twice: once with ``model`` as is and once
    inside ``accelerator.unwrap_model(model).disable_adapter()``.

    Args:
        accelerator: used to unwrap ``model`` before disabling its adapter.
        model: model exposing ``generate``.
        tokenizer: provides ``pad_token_id`` and ``batch_decode``.
        input_ids: prompt token ids passed to ``generate``.
        attention_mask: attention mask passed to ``generate``.
        return_ids: when true, the policy's token ids are returned as a third
            element.
        generation_config: forwarded to ``generate``.
        max_length: when given, both outputs are right-padded with the pad
            token to this length before decoding; ``None`` leaves them as
            generated.

    Returns:
        ``(policy_texts, reference_texts)``, or
        ``(policy_texts, reference_texts, policy_ids)`` when ``return_ids`` is
        true.
    """
    policy_output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
    )

    with accelerator.unwrap_model(model).disable_adapter():
        reference_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
        )

    if max_length is not None:
        policy_output = _pad_to_length(policy_output, max_length, tokenizer.pad_token_id)
        reference_output = _pad_to_length(reference_output, max_length, tokenizer.pad_token_id)

    policy_output_decoded = tokenizer.batch_decode(policy_output, skip_special_tokens=True)
    reference_output_decoded = tokenizer.batch_decode(reference_output, skip_special_tokens=True)

    if return_ids:
        return policy_output_decoded, reference_output_decoded, policy_output
    return policy_output_decoded, reference_output_decoded
def _eval_query_tensors(batch, tokenizer, max_length, device):
    """Return one unpadded 1-D prompt tensor per example of ``batch``."""
    if "input_ids" in batch:
        input_ids = batch["input_ids"]
        attention_mask = batch.get("attention_mask")
        if attention_mask is None:
            return [ids.to(device) for ids in input_ids]
        # Drop padding positions so every query has its true length.
        return [ids[mask.bool()].to(device) for ids, mask in zip(input_ids, attention_mask)]

    # Batches of raw text: tokenize the prompts here, without padding.
    encoded = tokenizer(list(batch["prompt"]), truncation=True, max_length=max_length)
    return [torch.tensor(ids, device=device) for ids in encoded["input_ids"]]


def gold_eval(dataloader, model, gold_model, accelerator, epoch, log_n_samples_during_eval=0):
    """
    Score greedy policy generations with the gold reward model.

    For each batch the prompts are completed greedily with
    ``ppo_trainer.generate``, prompt and response text are concatenated and
    scored by ``gold_model``, and the scores are gathered across processes.

    NOTE: besides its arguments this function reads the module-level
    ``tokenizer``, ``ppo_trainer`` and ``script_args`` objects created in the
    ``__main__`` section.

    Args:
        dataloader: yields batches containing ``"prompt"`` (list of strings)
            and optionally ``"input_ids"`` / ``"attention_mask"``.
        model: unused; generation goes through ``ppo_trainer``.
        gold_model: model whose first output holds the gold rewards.
        accelerator: used for device placement, gathering and logging.
        epoch: value logged under ``"epoch"``.
        log_n_samples_during_eval: maximum number of ``[prompt, response]``
            pairs collected (on the local main process) for logging.

    Returns:
        ``(mean_gold_reward, samples_to_log)``. The mean is NaN when the
        dataloader yields no samples.
    """
    samples_to_log = []
    gold_reward_sum = 0.0
    total_samples = 0

    greedy_generation_kwargs = {
        "min_length": -1,
        "top_p": 1.0,
        "do_sample": False,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "max_new_tokens": script_args.output_max_length,
    }

    for batch in tqdm(
        dataloader,
        disable=not accelerator.is_local_main_process,
        desc="Gold Eval",
    ):
        query_tensors = _eval_query_tensors(batch, tokenizer, script_args.input_max_length, accelerator.device)

        full_response_tensors = ppo_trainer.generate(
            query_tensors,
            return_prompt=True,
            **greedy_generation_kwargs,
        )
        # The generations include the prompt; keep only the new tokens.
        response_tensors = [
            full_response[len(query) :] for query, full_response in zip(query_tensors, full_response_tensors)
        ]
        batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

        texts = [q + r for q, r in zip(batch["prompt"], batch["response"])]
        policy_output = tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt", return_token_type_ids=False
        ).to(accelerator.device)

        with torch.no_grad():
            gold_rewards = gold_model(
                input_ids=policy_output["input_ids"], attention_mask=policy_output["attention_mask"]
            )[0]

        # After gathering, every process holds the rewards of all processes,
        # so each one can accumulate the same global statistics.
        gold_rewards = accelerator.gather_for_metrics(gold_rewards)
        gold_reward_sum += gold_rewards.sum().item()
        total_samples += gold_rewards.size(0)

        if accelerator.is_local_main_process:
            for prompt, resp in zip(batch["prompt"], batch["response"]):
                if len(samples_to_log) >= log_n_samples_during_eval:
                    break
                samples_to_log.append([prompt, resp])

    gold_reward_mean = gold_reward_sum / total_samples if total_samples else float("nan")

    if accelerator.is_local_main_process:
        print(f"gold reward mean {gold_reward_mean}")
        gold_log = {
            "eval/gold_rewards_mean": gold_reward_mean,
            "epoch": epoch,
        }
        if samples_to_log:
            # Each logged row is a [prompt, response] pair.
            gold_log["game_log"] = wandb.Table(
                columns=["Prompt", "Policy"],
                rows=samples_to_log,
            )
        accelerator.log(gold_log)

    return gold_reward_mean, samples_to_log
if __name__ == "__main__":
    # Script entry point: parse arguments, build the policy / trainer and the
    # optional gold model, then run the PPO loop.
    #
    # NOTE: the names assigned here (script_args, tokenizer, ppo_trainer, ...)
    # are module-level on purpose; helper functions in this file read some of
    # them as globals.
    import os  # only needed here, to build checkpoint paths

    parser = HfArgumentParser(ScriptArguments)
    script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0]

    # Periodic gold evaluation needs a gold model; fail before anything is
    # loaded instead of at the first evaluation step.
    if script_args.eval_steps is not None and script_args.gold_model_name is None:
        raise ValueError("eval_steps is set but no gold_model_name was given")

    config = PPOConfig(
        steps=script_args.steps,
        model_name=script_args.model_name,
        learning_rate=script_args.learning_rate,
        log_with=script_args.log_with,
        batch_size=script_args.batch_size,
        mini_batch_size=script_args.mini_batch_size,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        optimize_cuda_cache=True,
        early_stopping=script_args.early_stopping,
        target_kl=script_args.target_kl,
        ppo_epochs=script_args.ppo_epochs,
        seed=script_args.seed,
        init_kl_coef=script_args.init_kl_coef,
        adap_kl_ctrl=script_args.adap_kl_ctrl,
        accelerator_kwargs={"kwargs_handlers": [DistributedDataParallelKwargs(find_unused_parameters=False)]},
    )

    set_seed(config.seed)

    model, tokenizer = create_and_prepare_model(script_args)
    train_dataset = create_and_prepare_dataset(script_args, tokenizer, script_args.train_split)

    ppo_trainer = PPOTrainer(
        config,
        model,
        ref_model=None,
        tokenizer=tokenizer,
        dataset=train_dataset,
        data_collator=collator,
    )

    if script_args.gold_model_name is not None:
        gold_model = create_and_prepare_gold_model(script_args, ppo_trainer.accelerator)
        eval_dataloader = create_and_prepare_eval(script_args, tokenizer, ppo_trainer.accelerator)

        if script_args.just_eval:
            # Evaluation-only run: score the eval split once and stop.
            gold_eval(
                eval_dataloader,
                ppo_trainer.model,
                gold_model,
                ppo_trainer.accelerator,
                epoch=0,
                log_n_samples_during_eval=0,
            )
            raise SystemExit(0)

    if script_args.separate_reward_model:
        # NOTE(review): sentiment_pipe and sent_kwargs are not used further
        # down in this section; rewards come from
        # ppo_trainer.compute_reward_model_score.
        device = ppo_trainer.accelerator.device
        if ppo_trainer.accelerator.num_processes == 1:
            device = 0 if torch.cuda.is_available() else "cpu"
        sentiment_pipe = pipeline(
            "sentiment-analysis",
            model=script_args.separate_reward_model,
            device_map={"": Accelerator().local_process_index},
            model_kwargs={"load_in_8bit": True},
            tokenizer=tokenizer,
            return_token_type_ids=False,
        )
        sent_kwargs = {
            "return_all_scores": True,
            "function_to_apply": "none",
            "batch_size": 16,
            "truncation": True,
        }

    # Sampling settings for the training rollouts.
    generation_kwargs = {
        "min_length": -1,
        "top_k": 0.0,
        "top_p": 1.0,
        "do_sample": True,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    output_length_sampler = LengthSampler(script_args.output_min_length, script_args.output_max_length)

    for epoch, batch in tqdm(
        enumerate(ppo_trainer.dataloader),
        total=config.total_ppo_epochs,
        disable=not ppo_trainer.accelerator.is_local_main_process,
    ):
        if epoch >= config.total_ppo_epochs:
            break

        question_tensors = batch["input_ids"]

        full_response_tensors = ppo_trainer.generate(
            question_tensors,
            return_prompt=True,
            length_sampler=output_length_sampler,
            **generation_kwargs,
        )
        # The generations include the prompt; keep only the new tokens.
        response_tensors = [
            full_response[len(question) :]
            for question, full_response in zip(question_tensors, full_response_tensors)
        ]
        batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

        # Score prompt + response text with the reward model.
        texts = [q + r for q, r in zip(batch["query"], batch["response"])]
        policy_output = tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt", return_token_type_ids=False
        ).to(ppo_trainer.accelerator.device)

        raw_rewards = ppo_trainer.compute_reward_model_score(**policy_output)
        rewards = [raw_reward - script_args.reward_baseline for raw_reward in raw_rewards]

        # Run PPO step
        if not script_args.just_eval:
            stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
        else:
            stats = {}

        if script_args.eval_steps is not None and epoch % script_args.eval_steps == 0:
            if script_args.gold_eval_greedy:
                greedy_generation_kwargs = {
                    "min_length": -1,
                    "top_p": 1.0,
                    "do_sample": False,
                    "pad_token_id": tokenizer.pad_token_id,
                    "eos_token_id": tokenizer.eos_token_id,
                    "max_new_tokens": script_args.output_max_length,
                }
                greedy_output = ppo_trainer.generate(
                    question_tensors,
                    return_prompt=True,
                    **greedy_generation_kwargs,
                )

                max_length = script_args.input_max_length + script_args.output_max_length
                # Decode and re-tokenize so the gold model receives a padded
                # batch encoding rather than a list of strings.
                greedy_texts = tokenizer.batch_decode(greedy_output, skip_special_tokens=True)
                policy_output = tokenizer(
                    greedy_texts,
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt",
                    return_token_type_ids=False,
                ).to(ppo_trainer.accelerator.device)

            # Without greedy evaluation the sampled policy_output from above
            # is scored instead.
            with torch.no_grad():
                gold_rewards = gold_model(**policy_output)[0]
        else:
            gold_rewards = None

        stats["epoch"] = epoch
        ppo_trainer.log_stats(stats, batch, rewards, gold_rewards)

        if script_args.save_strategy != "no" and epoch > 0 and epoch % script_args.save_steps == 0:
            # os.path.join works whether or not output_dir ends with a separator.
            ppo_trainer.save_pretrained(os.path.join(script_args.output_dir, f"step_{epoch}"))