H2OTest / llm_studio /src /losses /text_dpo_modeling_losses.py
elineve's picture
Upload 301 files
07423df
raw
history blame
5.02 kB
"""
Loss Implementation based upon
https://github.com/eric-mitchell/direct-preference-optimization
"""
import logging
from typing import Any, KeysView
import torch
import torch.nn.functional as F
from torch import nn
# Public API of this module: a star-import exposes only the ``Losses`` factory.
__all__ = ["Losses"]
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class DPOLoss(nn.Module):
    """
    Implements
    "Direct Preference Optimization:
    Your Language Model is Secretly a Reward Model"
    from https://arxiv.org/abs/2305.18290

    The only hyper-parameter read from the configuration is
    ``cfg.training.beta``. Subclasses may replace the per-sample loss by
    overriding :meth:`get_losses`; :meth:`forward` always invokes it with the
    single keyword argument ``logits``.
    """

    def __init__(self, cfg: Any):
        """Keep a reference to the configuration.

        Args:
            cfg: configuration object exposing ``cfg.training.beta``.
        """
        super().__init__()
        self.cfg = cfg

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ):
        """Compute the batch loss and the implicit rewards.

        The four inputs are combined element-wise, so they must be
        broadcast-compatible.
        NOTE(review): presumably one log-probability per sample, i.e. shape
        (batch,) - confirm against the caller.

        Args:
            policy_chosen_logps: log-probs of the chosen answers (policy).
            policy_rejected_logps: log-probs of the rejected answers (policy).
            reference_chosen_logps: log-probs of the chosen answers
                (reference model).
            reference_rejected_logps: log-probs of the rejected answers
                (reference model).

        Returns:
            Tuple ``(loss, chosen_reward, rejected_reward)``; each element is
            the ``.mean()`` of the corresponding per-sample tensor. The two
            rewards are computed from detached log-ratios, so they carry no
            gradient.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps

        losses = self.get_losses(logits=pi_logratios - ref_logratios)

        # Rewards are reported for monitoring only, hence the detach().
        chosen_rewards = (
            self.cfg.training.beta
            * (policy_chosen_logps - reference_chosen_logps).detach()
        )
        rejected_rewards = (
            self.cfg.training.beta
            * (policy_rejected_logps - reference_rejected_logps).detach()
        )

        return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()

    def get_losses(self, logits, label_smoothing: float = 0.0):
        """Return the per-sample (conservative) DPO loss.

        Computes
        ``-(1 - label_smoothing) * logsigmoid(beta * logits)
        - label_smoothing * logsigmoid(-beta * logits)``.

        Args:
            logits: difference between the policy log-ratio and the reference
                log-ratio, as built in :meth:`forward`.
            label_smoothing: weight put on the flipped preference label.
                The default ``0.0`` gives the original DPO loss, which is
                what :meth:`forward` uses.

        Returns:
            Tensor of per-sample losses with the same shape as ``logits``.
        """
        # The beta is a temperature parameter for the DPO loss,
        # typically something in the range of 0.1 to 0.5.
        # We ignore the reference model as beta -> 0.
        # The label_smoothing parameter encodes our uncertainty about the
        # labels and calculates a conservative DPO loss.
        # See https://ericmitchell.ai/cdpo.pdf for more details
        beta = self.cfg.training.beta
        losses = (
            -F.logsigmoid(beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-beta * logits) * label_smoothing
        )
        return losses
class KTOPairLoss(nn.Module):
    """
    Paired KTO loss.

    Adopted from https://github.com/ContextualAI/HALOs
    and https://github.com/huggingface/trl

    Reads ``cfg.training.beta`` as its only hyper-parameter.
    """

    def __init__(self, cfg: Any):
        """Keep a reference to the configuration (``cfg.training.beta``)."""
        super().__init__()
        self.cfg = cfg

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ):
        """Compute the batch loss and the implicit rewards.

        NOTE(review): the per-sample losses are concatenated along dim 0, so
        the inputs are presumably 1-D tensors of shape (batch,) - confirm
        against the caller.

        Returns:
            Tuple ``(loss, chosen_reward, rejected_reward)``, each reduced
            with ``.mean()``. Rewards come from detached log-ratios and are
            cast to float32.
        """
        beta = self.cfg.training.beta

        # Policy-vs-reference log-ratios for each half of the pair.
        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps

        # Batch-level KL estimates: mean log-ratio, floored at zero.
        chosen_kl = chosen_logratios.mean().clamp(min=0)
        rejected_kl = rejected_logratios.mean().clamp(min=0)

        # Each half is measured against the KL estimate of the other half.
        chosen_part = 1 - F.sigmoid(beta * (chosen_logratios - rejected_kl))
        rejected_part = 1 - F.sigmoid(beta * (chosen_kl - rejected_logratios))
        losses = torch.cat((chosen_part, rejected_part), 0)

        chosen_rewards = (beta * chosen_logratios.detach()).float()
        rejected_rewards = (beta * rejected_logratios.detach()).float()

        return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()
class HingeLoss(DPOLoss):
    """
    Hinge variant of the preference loss.

    Only the per-sample loss is replaced; ``forward`` is inherited from
    ``DPOLoss``.
    """

    def get_losses(self, logits):
        """Return ``relu(1 - beta * logits)`` element-wise.

        Args:
            logits: tensor handed over by the inherited ``forward``.
        """
        scaled_logits = self.cfg.training.beta * logits
        return torch.relu(1 - scaled_logits)
class IPOLoss(DPOLoss):
    """
    Implements "A General Theoretical Paradigm
    to Understand Learning from Human Preferences"
    from https://arxiv.org/pdf/2310.12036.pdf

    Only the per-sample loss is replaced; ``forward`` is inherited from
    ``DPOLoss``.
    """

    def get_losses(self, logits):
        """Return the squared distance of ``logits`` to ``1 / (2 * beta)``.

        Args:
            logits: tensor handed over by the inherited ``forward``.
        """
        # eqn (17) of the https://arxiv.org/pdf/2310.12036.pdf
        # where beta is the real, positive KL parameter for the IPO loss,
        # denoted by tau in the paper (see also eqn (6)).
        target = 1 / (2 * self.cfg.training.beta)
        residual = logits - target
        return residual ** 2
class Losses:
    """Losses factory."""

    # Registry: loss name -> loss class.
    _losses = {
        "DPOLoss": DPOLoss,
        "HingeLoss": HingeLoss,
        "IPOLoss": IPOLoss,
        "KTOPairLoss": KTOPairLoss,
    }

    @classmethod
    def names(cls) -> KeysView:
        """Return a view over the registered loss names."""
        return cls._losses.keys()

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Losses.

        Args:
            name: losses name

        Returns:
            A class to build the Losses; ``DPOLoss`` when ``name`` is not
            registered.
        """
        registry = cls._losses
        if name in registry:
            return registry[name]
        # Unknown names deliberately fall back to the default loss.
        return DPOLoss
# Per-loss boolean flag, keyed by the names registered in ``Losses``.
# see https://github.com/huggingface/trl/commit/29d439a2043edf4455b05cae5a1e2ade69d22794
# NOTE(review): presumably True marks losses whose sequence log-probabilities
# are averaged rather than summed - confirm against the code consuming this.
LOSS_REDUCTION = {
    **dict.fromkeys(("DPOLoss", "KTOPairLoss"), False),
    **dict.fromkeys(("HingeLoss", "IPOLoss"), True),
}