import os
import math
import torch
from tqdm import tqdm
from typing import Callable, Dict, List, Literal, Optional, Tuple

from transformers import Seq2SeqTrainingArguments, TrainerState
from transformers.modeling_utils import PreTrainedModel

from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

from .peft_trainer import PeftTrainer, LogCallback
from .config import FinetuningArguments
from .other import (
    AverageMeter,
    get_logger,
    get_logits_processor
)


logger = get_logger(__name__)


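# A single AutoModelForCausalLMWithValueHead serves as both the policy and the
# reward model: replace_model() switches the active PEFT adapter and swaps the
# value-head weights between the "default" (policy) head and the "reward" head.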
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
    if target == "reward": # back up the default (policy) value head before overwriting it
        valuehead_state_dict = model.v_head.state_dict()

        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])

    model.pretrained_model.set_adapter(target) # activate the target adapter
    model.v_head.load_state_dict({
        "summary.weight": getattr(model, "{}_head_weight".format(target)),
        "summary.bias": getattr(model, "{}_head_bias".format(target))
    })


def cast_layernorm_dtype(
    model: AutoModelForCausalLMWithValueHead,
    layer_norm_names: List[str] = ["norm", "ln_f"],
    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:

    layer_norm_state_dict = {}

    for name, param in model.named_parameters():
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
            if layer_norm_params is not None: # restore the backed-up layer-norm weights
                param.data = layer_norm_params[name]
            else: # back up the current weights and cast the layer norms to float16
                layer_norm_state_dict[name] = param.data.detach().clone()
                param.data = param.data.to(torch.float16)

    return model, layer_norm_state_dict


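# PPOPeftTrainer.generate() below brackets generation with two cast_layernorm_dtype()
# calls: the first backs up the layer-norm weights and casts them to float16, the
# second restores the backup once generation is finished.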
class PPOPeftTrainer(PPOTrainer, PeftTrainer):
    r"""
    Inherits PPOTrainer and PeftTrainer to implement the PPO training stage.
    """

    def __init__(
        self,
        training_args: Seq2SeqTrainingArguments,
        finetuning_args: FinetuningArguments,
        callbacks: List[LogCallback],
        **kwargs
    ):
        PPOTrainer.__init__(self, **kwargs)
        self.args = training_args
        self.finetuning_args = finetuning_args
        self.log_callback = callbacks[0]
        self.state = TrainerState()
        self.data_collator = self.accelerator.prepare(kwargs["data_collator"])

    def ppo_train(self, max_target_length: int) -> None:
        r"""
        Implements the training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
        """
        total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.args.world_size
        len_dataloader = len(self.dataloader)
        num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1)
        num_examples = len(self.dataset)
        num_train_epochs = self.args.num_train_epochs
        max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)

        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        if self.is_world_process_zero():
            logger.info("***** Running training *****")
            logger.info(f"  Num examples = {num_examples}")
            logger.info(f"  Num Epochs = {num_train_epochs}")
            logger.info(f"  Instantaneous batch size per device = {self.config.batch_size}")
            logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
            logger.info(f"  Gradient Accumulation steps = {self.config.gradient_accumulation_steps}")
            logger.info(f"  Total optimization steps = {max_steps}")
            logger.info(f"  Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")

        gen_kwargs = {
            "top_k": 0.0,
            "top_p": 1.0,
            "do_sample": True,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "logits_processor": get_logits_processor()
        }
        output_length_sampler = LengthSampler(max_target_length // 2, max_target_length)
        unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)

        dataiter = iter(self.dataloader)
        steps_trained = 0
        loss_meter = AverageMeter()
        reward_meter = AverageMeter()

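        # Each outer `step` consumes `gradient_accumulation_steps` mini-batches from the
        # dataloader; for every mini-batch the loop generates responses, scores them with
        # the reward head, and runs one PPO update. Logging and checkpointing happen at
        # the outer-step granularity.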
        for step in tqdm(range(max_steps), disable=not self.is_world_process_zero()):

            for _ in range(self.config.gradient_accumulation_steps):

                batch = next(dataiter)
                steps_trained += 1

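                # Generation relies on the KV cache, which conflicts with gradient
                # checkpointing, so both are toggled for the rollout phase and
                # restored before the PPO step.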
                unwrapped_model.gradient_checkpointing_disable()
                unwrapped_model.config.use_cache = True

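                # Roll out: generate responses for the left-padded queries; with
                # return_prompt=False only the generated response tokens are returned.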
                query_tensors: torch.Tensor = batch["input_ids"]
                response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)

                queries: List[torch.Tensor] = []
                responses: List[torch.Tensor] = []
                for i in range(len(query_tensors)):
                    query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
                    response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
                    queries.append(query_tensors[i, query_length:]) # remove left padding
                    if response_length < 2: # make sure the response has at least 2 tokens
                        responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id))
                    else:
                        responses.append(response_tensors[i, :response_length]) # remove right padding

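                # Score the (query, response) pairs: switch to the reward adapter and value
                # head, use the value predicted for the last token as the scalar reward, then
                # switch back to the default (policy) adapter.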
                replace_model(unwrapped_model, target="reward")
                _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
                rewards = [reward for reward in values[:, -1].to(torch.float32)]
                replace_model(unwrapped_model, target="default")

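                # PPO update: restore gradient checkpointing and disable the cache for the
                # optimization pass.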
                unwrapped_model.gradient_checkpointing_enable()
                unwrapped_model.config.use_cache = False

                stats = self.step(queries, responses, rewards)

                loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
                reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

                if steps_trained == len_dataloader: # restart the dataloader when it is exhausted
                    dataiter = iter(self.dataloader)
                    steps_trained = 0

            if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
                logs = {
                    "loss": round(loss_meter.avg, 4),
                    "reward": round(reward_meter.avg, 4),
                    "learning_rate": stats["ppo/learning_rate"],
                    "epoch": round(step / num_steps_per_epoch, 2)
                }
                print(logs)
                logs["step"] = step
                self.state.log_history.append(logs)
                self.log_callback.on_log(self.args, self.state, None)
                loss_meter.reset()
                reward_meter.reset()

            if (step+1) % self.args.save_steps == 0: # save checkpoint
                self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))

    @torch.no_grad()
    def generate(
        self,
        inputs: Dict[str, torch.Tensor],
        length_sampler: Optional[Callable] = None,
        return_prompt: Optional[bool] = True,
        **generation_kwargs,
    ) -> torch.Tensor:
        r"""
        Generates the model's responses given queries.

        Subclass and override to inject custom behavior.
        """
        self.model, layer_norm_params = cast_layernorm_dtype(self.model)

        if length_sampler is not None:
            generation_kwargs["max_new_tokens"] = length_sampler()

        unwrapped_model = self.accelerator.unwrap_model(self.model)

        response = unwrapped_model.generate(**inputs, **generation_kwargs)

        # Presumably avoids `generate` re-deriving (and warning about) a default
        # GenerationConfig from the model config on subsequent calls.
        if unwrapped_model.pretrained_model.generation_config._from_model_config:
            unwrapped_model.pretrained_model.generation_config._from_model_config = False

        # Restore the layer-norm weights that were backed up before generation.
        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params=layer_norm_params)

        if not return_prompt and not self.is_encoder_decoder:
            return response[:, inputs["input_ids"].size(1):]
        return response

    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""
        Saves model checkpoint.

        Subclass and override to inject custom behavior.
        """
        if self.args.should_save:
            self._save(output_dir)
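

# A minimal usage sketch (not part of this module): it assumes the surrounding
# workflow has already built `ppo_config` (a trl.PPOConfig), `model` (an
# AutoModelForCausalLMWithValueHead carrying "default" and "reward" adapters),
# `tokenizer`, `dataset`, `data_collator`, `training_args`, and `finetuning_args`.
# These names, the no-argument LogCallback() construction, and the exact keyword
# set are illustrative assumptions, not definitions from this file.
#
#     ppo_trainer = PPOPeftTrainer(
#         training_args=training_args,
#         finetuning_args=finetuning_args,
#         callbacks=[LogCallback()],
#         config=ppo_config,
#         model=model,
#         ref_model=None,  # assumes the trl version in use accepts a None reference model
#         tokenizer=tokenizer,
#         dataset=dataset,
#         data_collator=data_collator
#     )
#     ppo_trainer.ppo_train(max_target_length=512)
#     ppo_trainer.save_model()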