# File size: 11,862 Bytes
# 7362797
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from transformers import LogitsProcessor


class TopPProbabilityProcessor(LogitsProcessor):
    """Top-p (nucleus) filtering that operates on probabilities, not logits.

    Modified version of TopPLogitsWarper to act on probabilities.
    Changes:
    * filter_value changed from -inf to 0
    * removed softmax
    * renormalize L1
    """

    def __init__(
        self,
        top_p: float,
        min_tokens_to_keep: int = 1,
    ):
        """Validate and store the filtering parameters.

        Args:
            top_p: Probability mass to keep; accepted range is ``[0, 1]``.
            min_tokens_to_keep: Number of highest-probability entries that
                are never filtered out; must be an integer >= 1.

        Raises:
            ValueError: If ``top_p`` is outside ``[0, 1]`` (including NaN) or
                ``min_tokens_to_keep`` is not a positive integer.
        """
        top_p = float(top_p)
        # Written as a chained comparison so that NaN is rejected as well.
        if not 0.0 <= top_p <= 1.0:
            raise ValueError(
                f"`top_p` has to be a float >= 0 and <= 1, but is {top_p}"
            )
        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
            raise ValueError(
                f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
            )

        self.top_p = top_p
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Zero out the low-probability tail and renormalize.

        Args:
            input_ids: Not used by this processor.
            probs: Probabilities with the vocabulary on the last dimension
                (e.g. ``[batch, vocab]``).

        Returns:
            A new tensor shaped like ``probs`` in which filtered entries are 0
            and the remaining entries are divided by their sum along the last
            dimension.
        """
        # Ascending sort: the entries to drop form a prefix of the sorted row.
        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=False)
        cumulative_probs = sorted_probs.cumsum(dim=-1)

        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        # Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = False

        # Scatter the sorted mask back to the original ordering. The last
        # dimension is used (as for sort/cumsum/sum) so that inputs with any
        # number of leading dimensions are handled.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        probs = probs.masked_fill(indices_to_remove, 0.0)
        return probs / probs.sum(dim=-1, keepdim=True)


class DisallowTokensInIndexRangeLogitsProcessor(LogitsProcessor):
    def __init__(
        self, token_ids: list[int], start_index: int, end_index: int | None = None
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_index = start_index
        self.end_index = end_index if end_index is not None else math.inf

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        current_index = input_ids.shape[1]
        if self.start_index <= current_index < self.end_index:
            logits[:, self.token_ids] = -math.inf
        return logits


class DisallowTokensLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
    """Ban ``token_ids`` from sequence length 0 onwards, with no upper bound."""

    def __init__(self, token_ids: list[int]):
        super().__init__(token_ids, start_index=0)


class DisallowTokensAtIndexLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
    """Ban ``token_ids`` only when the sequence length equals ``index``."""

    def __init__(self, token_ids: list[int], index: int):
        # Half-open range [index, index + 1) covers exactly one position.
        super().__init__(token_ids, start_index=index, end_index=index + 1)


class DisallowTokensAfterIndexLogitsProcessor(
    DisallowTokensInIndexRangeLogitsProcessor
):
    """Ban ``token_ids`` once the sequence length is strictly above ``index``."""

    def __init__(self, token_ids: list[int], index: int):
        first_banned_index = index + 1
        super().__init__(token_ids, start_index=first_banned_index)


class DisallowTokensAtOrAfterIndexLogitsProcessor(
    DisallowTokensInIndexRangeLogitsProcessor
):
    """Ban ``token_ids`` once the sequence length reaches ``index``."""

    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, start_index=index)


class DisallowTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        token_ids: list[int],
        start_indices: list[int],
        end_indices: list[int] | None = None,
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_indices = torch.tensor(start_indices)
        self.end_indices = (
            torch.tensor(end_indices)
            if end_indices is not None
            else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
        )

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape = [batch, seq_len]
        # logits.shape = [batch, vocab]
        current_index = input_ids.shape[1]
        mask = (self.start_indices <= current_index) & (
            current_index < self.end_indices
        )
        # The following will fail if the mask is all False.
        # logits[mask, self.token_ids] = -math.inf
        logits[torch.where(mask)[0].unsqueeze(1), self.token_ids] = -math.inf
        return logits


class DisallowTokensAtBatchIndexLogitsProcessor(
    DisallowTokensInBatchIndexRangeLogitsProcessor
):
    """Ban ``token_ids`` on row ``i`` only at sequence length ``batch_index[i]``."""

    def __init__(self, token_ids: list[int], batch_index: list[int]):
        # Per-row half-open range [i, i + 1) covers exactly one position.
        end_indices = [start + 1 for start in batch_index]
        super().__init__(token_ids, batch_index, end_indices)


class AllowOnlyTokensInIndexRangeLogitsProcessor(LogitsProcessor):
    def __init__(
        self, token_ids: list[int], start_index: int, end_index: int | None = None
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_index = start_index
        self.end_index = end_index if end_index is not None else math.inf

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        current_index = input_ids.shape[1]
        if self.start_index <= current_index < self.end_index:
            replacement = torch.full_like(logits, -math.inf)
            replacement[:, self.token_ids] = logits[:, self.token_ids]
            logits[:] = replacement
        return logits


class AllowOnlyTokensLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
    """Allow only ``token_ids`` from sequence length 0 onwards, with no upper bound."""

    def __init__(self, token_ids: list[int]):
        super().__init__(token_ids, start_index=0)


class AllowOnlyTokensAtIndexLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
    """Allow only ``token_ids`` when the sequence length equals ``index``."""

    def __init__(self, token_ids: list[int], index: int):
        # Half-open range [index, index + 1) covers exactly one position.
        super().__init__(token_ids, start_index=index, end_index=index + 1)


class AllowOnlyTokensAfterIndexLogitsProcessor(
    AllowOnlyTokensInIndexRangeLogitsProcessor
):
    """Allow only ``token_ids`` once the sequence length is strictly above ``index``."""

    def __init__(self, token_ids: list[int], index: int):
        first_restricted_index = index + 1
        super().__init__(token_ids, start_index=first_restricted_index)


class AllowOnlyTokensAtOrAfterIndexLogitsProcessor(
    AllowOnlyTokensInIndexRangeLogitsProcessor
):
    """Allow only ``token_ids`` once the sequence length reaches ``index``."""

    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, start_index=index)


class AllowOnlyTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        token_ids: list[int],
        start_indices: list[int],
        end_indices: list[int] | None = None,
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_indices = torch.tensor(start_indices)
        self.end_indices = (
            torch.tensor(end_indices)
            if end_indices is not None
            else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
        )

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape = [batch, seq_len]
        # logits.shape = [batch, vocab]
        current_index = input_ids.shape[1]
        mask = (self.start_indices <= current_index) & (
            current_index < self.end_indices
        )

        valid_batch_indices = torch.where(mask)[0].unsqueeze(1)
        full_mask = torch.full_like(logits, -math.inf)
        full_mask[valid_batch_indices, self.token_ids] = logits[
            valid_batch_indices, self.token_ids
        ]

        logits[:] = torch.where(full_mask != -math.inf, full_mask, logits)
        return logits


class AllowOnlyTokensAtRelativeOffsetLogitsProcessor(LogitsProcessor):
    """Restrict the next token when a trigger sits ``offset`` tokens back.

    For each row whose token at position ``-offset`` equals
    ``trigger_token_id``, every logit outside ``subsequent_token_ids`` is set
    to ``-inf``. Other rows are left untouched.
    """

    def __init__(
        self, trigger_token_id: int, subsequent_token_ids: list[int], offset: int
    ):
        """
        Args:
            trigger_token_id: Token id that activates the restriction.
            subsequent_token_ids: Token ids that stay allowed on a triggered row.
            offset: Distance from the end of the sequence at which the
                trigger is looked up (``input_ids[:, -offset]``).
        """
        self.offset = offset
        self.trigger_token_id = trigger_token_id
        self.subsequent_token_ids = torch.tensor(subsequent_token_ids)

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Mask the triggered rows in place and return ``logits``.

        ``input_ids`` is ``[batch, seq_len]`` and ``logits`` is
        ``[batch, vocab]``.
        """
        seq_len = input_ids.shape[1]
        if seq_len < self.offset:
            # Sequence too short to contain a token at the requested offset.
            return logits

        token_at_offset = input_ids[:, -self.offset]
        # Column vector so it broadcasts across the vocabulary dimension.
        row_is_triggered = (token_at_offset == self.trigger_token_id).unsqueeze(-1)

        forbidden = torch.ones_like(logits, dtype=bool)
        forbidden[:, self.subsequent_token_ids] = False

        to_mask = forbidden & row_is_triggered
        return logits.masked_fill_(to_mask, -math.inf)


class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor):
    """Restrict the next token while a trigger is within the last ``width`` tokens.

    For each row whose trailing window of up to ``width`` tokens contains
    ``trigger_token_id``, every logit outside ``allowed_token_ids`` is set to
    ``-inf``. Other rows are left untouched.
    """

    def __init__(self, trigger_token_id: int, allowed_token_ids: list[int], width: int):
        """
        Args:
            trigger_token_id: Token id that activates the restriction.
            allowed_token_ids: Token ids that stay allowed on a triggered row.
            width: Number of trailing tokens inspected for the trigger. A
                width of 0 (or less) inspects nothing, so no row triggers.
        """
        self.trigger_token_id = trigger_token_id
        self.allowed_token_ids = torch.tensor(allowed_token_ids).unsqueeze(
            0
        )  # shape: [1, num_allowed_tokens]
        self.width = width

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Mask the triggered rows in place and return ``logits``.

        ``input_ids`` is ``[batch, seq_len]`` and ``logits`` is
        ``[batch, vocab]``.
        """
        seq_len = input_ids.shape[1]
        width = max(0, min(self.width, seq_len))
        # Slice from an explicit start position: `input_ids[:, -width:]` would
        # select the WHOLE sequence when width == 0, because -0 == 0.
        window = input_ids[:, seq_len - width :]
        trigger_positions = (window == self.trigger_token_id).any(dim=1).unsqueeze(-1)

        disallowed_tokens_mask = torch.ones_like(logits, dtype=torch.bool)
        disallowed_tokens_mask[:, self.allowed_token_ids] = False

        return logits.masked_fill_(
            disallowed_tokens_mask & trigger_positions,
            -math.inf,
        )


class CFGLogitsProcessor(LogitsProcessor):
    """Classifier-free guidance using a separately tracked unconditional context.

    On every call the last column of ``input_ids`` is appended to
    ``self.unconditional_ids``, ``model`` is run on that sequence, and the
    result is combined with the incoming logits as
    ``guidance_scale * (conditioned - unconditioned) + unconditioned``.
    """

    def __init__(
        self,
        guidance_scale: float,
        unconditional_ids: torch.LongTensor,
        model,
    ):
        """
        Args:
            guidance_scale: Weight applied to the conditioned/unconditioned
                difference.
            unconditional_ids: Initial token ids of the unconditional context;
                grown by one column on each call.
            model: Callable invoked as ``model(ids)``; its result is indexed
                with ``[:, -1, :]`` — presumably ``[batch, seq, vocab]``
                logits, TODO confirm against the caller.
        """
        self.model = model
        self.guidance_scale = guidance_scale
        self.unconditional_ids = unconditional_ids

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Return the guided logits as a new tensor.

        Side effect: ``self.unconditional_ids`` is replaced by a tensor that
        is one token longer.
        """
        newest_tokens = input_ids[:, -1:]
        self.unconditional_ids = torch.cat(
            (self.unconditional_ids, newest_tokens), dim=1
        )

        # Only the prediction at the final position is needed.
        unconditioned_logits = self.model(self.unconditional_ids)[:, -1, :]

        guidance_delta = logits - unconditioned_logits
        return self.guidance_scale * guidance_delta + unconditioned_logits


class InBatchCFGLogitsProcessor(LogitsProcessor):
    """Classifier-free guidance where both branches share one batch.

    The incoming logits are split in half along dim 0: the first half is
    treated as conditioned, the second as unconditioned. Both halves of the
    output hold the same mixed logits.
    """

    def __init__(self, guidance_scale: float):
        """
        Args:
            guidance_scale: Weight applied to the conditioned/unconditioned
                difference.
        """
        self.guidance_scale = guidance_scale

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Mix the two halves and return them duplicated.

        ``input_ids`` is ``[2*batch, seq-len]`` (not read here) and ``logits``
        is ``[2*batch, vocab]``.
        """
        halves = logits.chunk(2, dim=0)
        conditioned, unconditioned = halves

        guidance_delta = conditioned - unconditioned
        mixed = unconditioned + self.guidance_scale * guidance_delta

        # Duplicate so the output keeps the [2*batch, vocab] layout.
        return mixed.repeat(2, 1)


class InBatchInstructCFGLogitsProcessor(LogitsProcessor):
    """Two-scale classifier-free guidance with all branches in one batch.

    See https://arxiv.org/abs/2211.09800

    The incoming logits are split in three along dim 0, in this order:
    fully conditioned, image-conditioned, unconditioned. All three thirds of
    the output hold the same mixed logits.
    """

    def __init__(self, guidance_scale_text: float, guidance_scale_image: float):
        """
        Args:
            guidance_scale_text: Weight of the (full - image-conditioned) term.
            guidance_scale_image: Weight of the (image-conditioned -
                unconditioned) term.
        """
        self.guidance_scale_image = guidance_scale_image
        self.guidance_scale_text = guidance_scale_text

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Mix the three thirds and return them triplicated.

        ``input_ids`` is ``[3*batch, seq-len]`` (not read here) and ``logits``
        is ``[3*batch, vocab]``.
        """
        full_cond, image_cond, uncond = torch.chunk(logits, chunks=3, dim=0)

        image_term = self.guidance_scale_image * (image_cond - uncond)
        text_term = self.guidance_scale_text * (full_cond - image_cond)
        mixed = uncond + image_term + text_term

        # Triplicate so the output keeps the [3*batch, vocab] layout.
        return mixed.repeat(3, 1)