import os
import sys
import json
import torch
import logging
from typing import Dict, List, Optional

from transformers.trainer import TRAINER_STATE_NAME
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor

from peft.utils import WEIGHTS_NAME


IGNORE_INDEX = -100
VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json"


def get_logger(name: str) -> logging.Logger:
    return logging.getLogger(name)


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)]
)


logger = get_logger(__name__)


class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
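
# A minimal usage sketch (hypothetical variable names): track a running average of the
# per-batch loss during a training loop.
#
#     loss_meter = AverageMeter()
#     for loss in per_batch_losses:
#         loss_meter.update(loss, n=1)
#     print(loss_meter.avg)  # mean over all updates so far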


class InvalidScoreLogitsProcessor(LogitsProcessor):

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Replace NaN/Inf scores with finite values (zeros, with index 0 set to 1.0)
        # so that sampling during generation does not fail.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 0] = 1.0
        return scores


def get_logits_processor() -> LogitsProcessorList:
    logits_processor = LogitsProcessorList()
    logits_processor.append(InvalidScoreLogitsProcessor())
    return logits_processor
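
# A minimal sketch of how the processor list is typically consumed; assumes `model` and
# `tokenizer` are an already-loaded Hugging Face causal LM and its tokenizer.
#
#     inputs = tokenizer("Hello", return_tensors="pt")
#     outputs = model.generate(**inputs, logits_processor=get_logits_processor())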


def prepare_model_for_training(
    model: PreTrainedModel,
    finetuning_type: str,
    output_embedding_layer_name: Optional[str] = "lm_head",
    use_gradient_checkpointing: Optional[bool] = True,
    layer_norm_names: Optional[List[str]] = ["norm", "ln_f"]
) -> PreTrainedModel:

    # Cast 1-D layer-norm parameters to fp32 for training stability.
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
            param.data = param.data.to(torch.float32)

    if use_gradient_checkpointing:
        # Make sure the inputs to the first layer require gradients, otherwise
        # gradient checkpointing breaks when the embedding layer is frozen.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # the cache is incompatible with gradient checkpointing

    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
        input_dtype = output_embedding_layer.weight.dtype

        class CastOutputToFloat(torch.nn.Sequential):

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return super().forward(x.to(input_dtype)).to(torch.float32)

        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))

    return model
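
# A minimal sketch of the assumed call order (not taken from this file): prepare the base
# model before wrapping it with a PEFT adapter, e.g. via peft.get_peft_model.
#
#     model = prepare_model_for_training(model, finetuning_type="lora")
#     # model = get_peft_model(model, lora_config)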


def print_trainable_params(model: torch.nn.Module) -> None:
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # With DeepSpeed ZeRO-3, parameters are partitioned and numel() returns 0,
        # so fall back to ds_numel for the true element count.
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
        trainable_params, all_param, 100 * trainable_params / all_param))


def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    state_dict = model.state_dict()
    filtered_state_dict = {}

    # Keep only the trainable parameters, detached and moved to CPU, so checkpoints stay small.
    for k, v in model.named_parameters():
        if v.requires_grad:
            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()

    return filtered_state_dict
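
# A minimal sketch of the assumed use: pass the filtered state dict to save_pretrained so
# that only the trainable weights are written; `output_dir` is a hypothetical path.
#
#     model.save_pretrained(output_dir, state_dict=get_state_dict(model))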


def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    if not os.path.exists(weights_file):
        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
        return False
    model_state_dict = torch.load(weights_file, map_location="cpu")
    model.load_state_dict(model_state_dict, strict=False)  # skip missing keys
    return True


def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
    if not os.path.exists(valuehead_file):
        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
        return False
    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
    return True


def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
    r"""
    EMA implementation according to TensorBoard.
    """
    last = scalars[0]
    smoothed = list()
    for next_val in scalars:
        smoothed_val = last * weight + (1 - weight) * next_val
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed
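
# A worked example of the recurrence smoothed[t] = weight * smoothed[t-1] + (1 - weight) * x[t],
# with the running value initialized to the first element:
#
#     smooth([1.0, 2.0, 3.0], weight=0.9)
#     # -> [1.0, 1.1, 1.29] (up to floating-point rounding)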


def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
    import matplotlib.pyplot as plt
    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
        data = json.load(f)

    for key in keys:
        steps, metrics = [], []
        for i in range(len(data["log_history"])):
            if key in data["log_history"][i]:
                steps.append(data["log_history"][i]["step"])
                metrics.append(data["log_history"][i][key])

        if len(metrics) == 0:
            logger.warning(f"No metric {key} to plot.")
            continue

        plt.figure()
        plt.plot(steps, metrics, alpha=0.4, label="original")
        plt.plot(steps, smooth(metrics), label="smoothed")
        plt.title("training {} of {}".format(key, save_dictionary))
        plt.xlabel("step")
        plt.ylabel(key)
        plt.legend()
        plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
        print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))