File size: 7,342 Bytes
07423df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import logging
from typing import Any, Dict

import torch
from torch import nn
from transformers import AutoModelForCausalLM

from llm_studio.src.losses.text_causal_language_modeling_losses import (
    SampleAveragedCrossEntropyLoss,
)
from llm_studio.src.losses.text_dpo_modeling_losses import LOSS_REDUCTION
from llm_studio.src.metrics.text_causal_language_modeling_metrics import Perplexity
from llm_studio.src.utils.data_utils import batch_padding
from llm_studio.src.utils.modeling_utils import (
    create_nlp_backbone,
    generate,
    prepare_lora,
)

logger = logging.getLogger(__name__)


def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
) -> torch.Tensor:
    """
    Based upon the official implementation of DPO:
    https://github.com/eric-mitchell/direct-preference-optimization

    Compute the log probabilities of the given labels under the given logits.

    Args:
        logits:
            Logits of the model (unnormalized).
            Shape: (batch_size, sequence_length, vocab_size)
        labels:
            Labels for which to compute the log probabilities.
            Label tokens with a value of -100 are ignored.
            Shape: (batch_size, sequence_length)
        average_log_prob:
            If True, return the average log probability per (non-masked) token.
            Otherwise, return the sum of the
            log probabilities of the (non-masked) tokens.
    Returns:
        A tensor of shape (batch_size,) containing the average/sum
        log probabilities of the given labels under the given logits.
        A row whose labels are entirely masked yields 0 in both modes
        (rather than NaN from 0 / 0 when averaging).
    """
    assert logits.shape[:-1] == labels.shape

    # shift labels and logits to account for next token prediction:
    # the logit at position t scores the label at position t + 1.
    # See also text_causal_language_modeling_losses.py
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != -100

    # dummy token; we'll ignore the losses on these tokens when loss_mask is applied
    # Needed to be able to apply torch.gather with index=labels.unsqueeze(2)
    labels[labels == -100] = 0

    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)

    # masked_fill (instead of multiplying by the mask) keeps ignored positions
    # at exactly 0 even if their dummy-token logp is -inf (-inf * 0 = NaN).
    summed_logps = per_token_logps.masked_fill(~loss_mask, 0.0).sum(-1)

    if average_log_prob:
        # clamp guards rows with no valid tokens (e.g. answer fully truncated),
        # which would otherwise produce 0 / 0 = NaN and poison the loss.
        return summed_logps / loss_mask.sum(-1).clamp(min=1)
    return summed_logps


class Model(nn.Module):
    """
    Model for DPO language modeling problem type.

    The policy is the causal LM backbone with its LoRA adapters enabled; the
    reference model is the same backbone evaluated inside
    `backbone.disable_adapter()` without gradients, so LoRA is mandatory.
    """

    def __init__(self, cfg: Any):
        super().__init__()

        self.cfg = cfg
        self.backbone, self.backbone_config = create_nlp_backbone(
            cfg, model_class=AutoModelForCausalLM
        )

        # The reference logits are produced by disabling the LoRA adapters in
        # `forward`, which is only possible when LoRA is enabled.
        assert cfg.training.lora, "Need to enable lora for dpo training"
        self.backbone = prepare_lora(cfg=cfg, backbone=self.backbone)

        # Loss class is looked up by name from the configured loss registry;
        # it is called with policy/reference log-probs and returns
        # (loss, chosen_rewards, rejected_rewards) (see `_dpo_forward`).
        self.loss_fn = self.cfg.training.loss_class.get(
            self.cfg.training.loss_function
        )(self.cfg)
        if self.cfg.prediction.metric == "Perplexity":
            self.perplexity = Perplexity(self.cfg, reduce=False)

    def generate(self, batch: Dict, cfg: Any, streamer=None):
        """Generate text with the (LoRA-enabled) backbone via `generate`."""
        return generate(self.backbone, batch, cfg, streamer)

    def forward(
        self,
        batch: Dict,
        padding: bool = True,
    ) -> Dict:
        """
        Forward pass of DPO model.

        Runs four backbone forward passes per batch, roughly 4 times the cost
        of the causal language modeling model:
        - logits for chosen answer
        - logits for rejected answer
        - logits for chosen answer with reference model
        - logits for rejected answer with reference model

        Args:
            batch: dict holding `{answer}_input_ids`, `{answer}_attention_mask`
                and `{answer}_labels` for answer in ("chosen", "rejected").
            padding: if True, `batch_padding` is applied to each answer's
                input ids, attention mask and labels before the forward pass.

        Returns:
            Dict with the per-answer policy/reference log-probs, "loss",
            reward logging values, cross-entropy logging values and, in eval
            mode with the Perplexity metric, perplexity values.
        """
        gradient_checkpointing = self.cfg.architecture.gradient_checkpointing

        # disable cache if gradient checkpointing is enabled
        if gradient_checkpointing:
            self.backbone.config.use_cache = False
        try:
            return self._dpo_forward(batch, padding)
        finally:
            # enable cache again, also when the forward pass raised, so the
            # backbone is never left with the cache silently disabled
            if gradient_checkpointing:
                self.backbone.config.use_cache = True

    def _dpo_forward(self, batch: Dict, padding: bool) -> Dict:
        """Compute policy/reference log-probs, the DPO loss and logging values."""
        outputs: Dict = {}

        logits_dict = {}
        labels_dict = {}

        # Whether sequence log-probs are averaged or summed; loop-invariant.
        average_log_prob = LOSS_REDUCTION[self.cfg.training.loss_function]

        for answer in ["chosen", "rejected"]:
            input_ids_key = f"{answer}_input_ids"
            attention_mask_key = f"{answer}_attention_mask"
            labels_key = f"{answer}_labels"

            if padding:
                batch = batch_padding(
                    self.cfg,
                    batch,
                    self.training,
                    mask_key=attention_mask_key,
                    pad_keys=[
                        input_ids_key,
                        attention_mask_key,
                        labels_key,
                    ],
                )
            labels = batch[labels_key]

            # Policy pass: adapters enabled, gradients tracked.
            logits = self.backbone(
                input_ids=batch[input_ids_key],
                attention_mask=batch[attention_mask_key],
            ).logits

            logits_dict[answer] = logits
            labels_dict[answer] = labels

            outputs[f"{answer}_logps"] = get_batch_logps(
                logits,
                labels,
                average_log_prob=average_log_prob,
            )

            # Reference pass: same backbone with adapters disabled, no grads.
            with self.backbone.disable_adapter():
                with torch.no_grad():
                    reference_logits = self.backbone(
                        input_ids=batch[input_ids_key],
                        attention_mask=batch[attention_mask_key],
                    ).logits
                    outputs[f"{answer}_reference_logps"] = get_batch_logps(
                        reference_logits,
                        labels,
                        average_log_prob=average_log_prob,
                    )

        loss, chosen_rewards, rejected_rewards = self.loss_fn(
            policy_chosen_logps=outputs["chosen_logps"],
            policy_rejected_logps=outputs["rejected_logps"],
            reference_chosen_logps=outputs["chosen_reference_logps"],
            reference_rejected_logps=outputs["rejected_reference_logps"],
        )
        outputs["loss"] = loss

        # These values will be logged to Neptune if enabled, see train.py
        outputs["additional_log_chosen_rewards"] = chosen_rewards.detach()
        outputs["additional_log_rejected_rewards"] = rejected_rewards.detach()
        # Reward margin should increase over time
        outputs["additional_log_reward_margin"] = (
            chosen_rewards - rejected_rewards
        ).detach()

        # log sample average cross entropy, perplexity metric is also sample averaged.
        # These are logging-only values: computing them under no_grad avoids
        # saving autograd buffers for a graph that is never backpropagated.
        with torch.no_grad():
            cross_entropy = SampleAveragedCrossEntropyLoss(self.cfg)
            for answer in ["chosen", "rejected"]:
                outputs[f"additional_log_{answer}_cross_entropy_loss"] = (
                    cross_entropy(logits_dict[answer], labels_dict[answer])
                )

        if not self.training and self.cfg.prediction.metric == "Perplexity":
            outputs["perplexity"] = self.perplexity(
                logits_dict["chosen"], labels_dict["chosen"]
            )
            outputs["additional_log_rejected_perplexity"] = self.perplexity(
                logits_dict["rejected"], labels_dict["rejected"]
            )

        return outputs