import logging
from typing import Any, Dict

import torch
from torch import nn
from transformers import AutoModelForCausalLM

from llm_studio.src.losses.text_causal_language_modeling_losses import (
    SampleAveragedCrossEntropyLoss,
)
from llm_studio.src.losses.text_dpo_modeling_losses import LOSS_REDUCTION
from llm_studio.src.metrics.text_causal_language_modeling_metrics import Perplexity
from llm_studio.src.utils.data_utils import batch_padding
from llm_studio.src.utils.modeling_utils import (
    create_nlp_backbone,
    generate,
    prepare_lora,
)

logger = logging.getLogger(__name__)
def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
) -> torch.Tensor:
    """
    Based upon the official implementation of DPO:
    https://github.com/eric-mitchell/direct-preference-optimization

    Compute the log probabilities of the given labels under the given logits.

    Args:
        logits:
            Logits of the model (unnormalized).
            Shape: (batch_size, sequence_length, vocab_size)
        labels:
            Labels for which to compute the log probabilities.
            Label tokens with a value of -100 are ignored.
            Shape: (batch_size, sequence_length)
        average_log_prob:
            If True, return the average log probability per (non-masked) token.
            Otherwise, return the sum of the log probabilities
            of the (non-masked) tokens.
    Returns:
        A tensor of shape (batch_size,) containing the average/sum
        log probabilities of the given labels under the given logits.
    """
    assert logits.shape[:-1] == labels.shape

    # shift labels and logits to account for next token prediction
    # See also text_causal_language_modeling_losses.py
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != -100

    # dummy token; we'll ignore the losses on these tokens when loss_mask is applied
    # Needed to be able to apply torch.gather with index=labels.unsqueeze(2)
    labels[labels == -100] = 0
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    else:
        return (per_token_logps * loss_mask).sum(-1)
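
# Illustrative sketch (not part of the original module, shapes are hypothetical):
# for a batch of 2 sequences of length 8 over a vocabulary of 32 tokens,
# get_batch_logps reduces (batch, seq, vocab) logits and (batch, seq) labels to
# one value per sample:
#
#   logits = torch.randn(2, 8, 32)
#   labels = torch.randint(0, 32, (2, 8))
#   labels[:, :3] = -100                    # e.g. prompt tokens are masked out
#   get_batch_logps(logits, labels).shape   # -> torch.Size([2])
#
# With average_log_prob=False the result is the summed log probability of the
# non-masked label tokens; with average_log_prob=True it is the per-token mean.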

class Model(nn.Module):
    """
    Model for DPO language modeling problem type.
    """

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        self.backbone, self.backbone_config = create_nlp_backbone(
            cfg, model_class=AutoModelForCausalLM
        )

        assert cfg.training.lora, "Need to enable lora for dpo training"
        self.backbone = prepare_lora(cfg=cfg, backbone=self.backbone)

        self.loss_fn = self.cfg.training.loss_class.get(
            self.cfg.training.loss_function
        )(self.cfg)

        if self.cfg.prediction.metric == "Perplexity":
            self.perplexity = Perplexity(self.cfg, reduce=False)

    def generate(self, batch: Dict, cfg: Any, streamer=None):
        return generate(self.backbone, batch, cfg, streamer)

    def forward(
        self,
        batch: Dict,
        padding: bool = True,
    ) -> Dict:
        """
        Forward pass of the DPO model.

        Runtime is 4 times slower than the causal language modeling model,
        as we need to compute
        - logits for the chosen answer
        - logits for the rejected answer
        - logits for the chosen answer with the reference model
        - logits for the rejected answer with the reference model
        """
        # disable cache if gradient checkpointing is enabled
        if self.cfg.architecture.gradient_checkpointing:
            self.backbone.config.use_cache = False

        outputs: Dict = {}
        logits_dict = {}
        labels_dict = {}

        for answer in ["chosen", "rejected"]:
            if padding:
                batch = batch_padding(
                    self.cfg,
                    batch,
                    self.training,
                    mask_key=f"{answer}_attention_mask",
                    pad_keys=[
                        f"{answer}_input_ids",
                        f"{answer}_attention_mask",
                        f"{answer}_labels",
                    ],
                )
            logits = self.backbone(
                input_ids=batch[f"{answer}_input_ids"],
                attention_mask=batch[f"{answer}_attention_mask"],
            ).logits
            logits_dict[answer] = logits
            labels_dict[answer] = batch[f"{answer}_labels"]

            outputs[f"{answer}_logps"] = get_batch_logps(
                logits,
                batch[f"{answer}_labels"],
                average_log_prob=LOSS_REDUCTION[self.cfg.training.loss_function],
            )
            # With the LoRA adapters disabled, the backbone falls back to its
            # frozen base weights, which serve as the reference model. This
            # avoids keeping a second copy of the model in memory.
            with self.backbone.disable_adapter():
                with torch.no_grad():
                    reference_logits = self.backbone(
                        input_ids=batch[f"{answer}_input_ids"],
                        attention_mask=batch[f"{answer}_attention_mask"],
                    ).logits

                    outputs[f"{answer}_reference_logps"] = get_batch_logps(
                        reference_logits,
                        batch[f"{answer}_labels"],
                        average_log_prob=LOSS_REDUCTION[
                            self.cfg.training.loss_function
                        ],
                    )
        loss, chosen_rewards, rejected_rewards = self.loss_fn(
            policy_chosen_logps=outputs["chosen_logps"],
            policy_rejected_logps=outputs["rejected_logps"],
            reference_chosen_logps=outputs["chosen_reference_logps"],
            reference_rejected_logps=outputs["rejected_reference_logps"],
        )
        outputs["loss"] = loss

        # These values will be logged to Neptune if enabled, see train.py
        outputs["additional_log_chosen_rewards"] = chosen_rewards.detach()
        outputs["additional_log_rejected_rewards"] = rejected_rewards.detach()
        # Reward margin should increase over time
        outputs["additional_log_reward_margin"] = (
            chosen_rewards - rejected_rewards
        ).detach()

        # log sample averaged cross entropy; the perplexity metric is also sample averaged
        outputs["additional_log_chosen_cross_entropy_loss"] = (
            SampleAveragedCrossEntropyLoss(self.cfg)(
                logits_dict["chosen"], labels_dict["chosen"]
            ).detach()
        )
        outputs["additional_log_rejected_cross_entropy_loss"] = (
            SampleAveragedCrossEntropyLoss(self.cfg)(
                logits_dict["rejected"], labels_dict["rejected"]
            ).detach()
        )

        if not self.training and self.cfg.prediction.metric == "Perplexity":
            outputs["perplexity"] = self.perplexity(
                logits_dict["chosen"], labels_dict["chosen"]
            )
            outputs["additional_log_rejected_perplexity"] = self.perplexity(
                logits_dict["rejected"], labels_dict["rejected"]
            )

        # enable cache again if gradient checkpointing is enabled
        if self.cfg.architecture.gradient_checkpointing:
            self.backbone.config.use_cache = True

        return outputs
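
# Hedged usage sketch (illustrative only; tensor values are omitted):
#   model = Model(cfg)
#   batch = {
#       "chosen_input_ids": ...,      "rejected_input_ids": ...,
#       "chosen_attention_mask": ..., "rejected_attention_mask": ...,
#       "chosen_labels": ...,         "rejected_labels": ...,
#   }
#   outputs = model(batch)
#   outputs["loss"].backward()  # gradients flow only into the LoRA parameters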