# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from transformers import LogitsProcessor


class TopPProbabilityProcessor(LogitsProcessor):
    # Modified version of TopPLogitsWarper to act on probabilities.
    # Changes:
    #   * filter_value changed from -inf to 0
    #   * removed softmax
    #   * renormalize L1

    def __init__(
        self,
        top_p: float,
        min_tokens_to_keep: int = 1,
    ):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float in [0, 1], but is {top_p}")
        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
            raise ValueError(
                f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
            )
        self.top_p = top_p
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[batch, seq-len]
        # probs.shape=[batch, vocab]
        sorted_probs, sorted_indices = torch.sort(probs, descending=False)
        cumulative_probs = sorted_probs.cumsum(dim=-1)
        # Remove tokens whose cumulative probability mass (in ascending order)
        # falls below the (1 - top_p) threshold, i.e. everything outside the
        # top-p nucleus.
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        # Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
        # Scatter sorted tensors back to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        probs = probs.masked_fill(indices_to_remove, 0.0)
        # Renormalize so the surviving probabilities sum to 1 (L1 norm).
        probs = probs / probs.sum(dim=-1, keepdim=True)
        return probs
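
# Illustrative usage sketch (not part of the original module): this processor
# expects probabilities rather than logits, so a caller would typically apply
# a softmax first. All names below are hypothetical.
#
#   probs = torch.softmax(logits, dim=-1)                      # [batch, vocab]
#   probs = TopPProbabilityProcessor(top_p=0.9)(input_ids, probs)
#   next_token = torch.multinomial(probs, num_samples=1)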


class DisallowTokensInIndexRangeLogitsProcessor(LogitsProcessor):
    # Sets the given token ids to -inf whenever the index of the token being
    # generated falls inside [start_index, end_index).

    def __init__(
        self, token_ids: list[int], start_index: int, end_index: int | None = None
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_index = start_index
        self.end_index = end_index if end_index is not None else math.inf

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        current_index = input_ids.shape[1]
        if self.start_index <= current_index < self.end_index:
            logits[:, self.token_ids] = -math.inf
        return logits


class DisallowTokensLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
    def __init__(self, token_ids: list[int]):
        super().__init__(token_ids, 0)


class DisallowTokensAtIndexLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index, index + 1)


class DisallowTokensAfterIndexLogitsProcessor(
    DisallowTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index + 1)


class DisallowTokensAtOrAfterIndexLogitsProcessor(
    DisallowTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index)


class DisallowTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
    # Per-sample version of DisallowTokensInIndexRangeLogitsProcessor: each row
    # of the batch has its own [start, end) range in which token_ids are banned.

    def __init__(
        self,
        token_ids: list[int],
        start_indices: list[int],
        end_indices: list[int] | None = None,
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_indices = torch.tensor(start_indices)
        self.end_indices = (
            torch.tensor(end_indices)
            if end_indices is not None
            else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
        )

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape = [batch, seq_len]
        # logits.shape = [batch, vocab]
        current_index = input_ids.shape[1]
        mask = (self.start_indices <= current_index) & (
            current_index < self.end_indices
        )
        # The simpler `logits[mask, self.token_ids] = -math.inf` would fail
        # when the mask is all False, so index the selected rows explicitly.
        logits[torch.where(mask)[0].unsqueeze(1), self.token_ids] = -math.inf
        return logits


class DisallowTokensAtBatchIndexLogitsProcessor(
    DisallowTokensInBatchIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], batch_index: list[int]):
        super().__init__(token_ids, batch_index, [i + 1 for i in batch_index])
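
# Illustrative usage sketch (not part of the original module): with a batch of
# two sequences, the first sample could ban hypothetical token ids [5, 6] only
# at generation step 10, the second only at step 20:
#
#   processor = DisallowTokensAtBatchIndexLogitsProcessor([5, 6], [10, 20])
#   logits = processor(input_ids, logits)   # logits.shape == [2, vocab]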


class AllowOnlyTokensInIndexRangeLogitsProcessor(LogitsProcessor):
    # Inside [start_index, end_index), every token except the given token_ids
    # is set to -inf.

    def __init__(
        self, token_ids: list[int], start_index: int, end_index: int | None = None
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_index = start_index
        self.end_index = end_index if end_index is not None else math.inf

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        current_index = input_ids.shape[1]
        if self.start_index <= current_index < self.end_index:
            replacement = torch.full_like(logits, -math.inf)
            replacement[:, self.token_ids] = logits[:, self.token_ids]
            logits[:] = replacement
        return logits


class AllowOnlyTokensLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
    def __init__(self, token_ids: list[int]):
        super().__init__(token_ids, 0)


class AllowOnlyTokensAtIndexLogitsProcessor(
    AllowOnlyTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index, index + 1)


class AllowOnlyTokensAfterIndexLogitsProcessor(
    AllowOnlyTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index + 1)


class AllowOnlyTokensAtOrAfterIndexLogitsProcessor(
    AllowOnlyTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index)


class AllowOnlyTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
    # Per-sample version of AllowOnlyTokensInIndexRangeLogitsProcessor: each
    # row of the batch has its own [start, end) range in which only token_ids
    # are allowed.

    def __init__(
        self,
        token_ids: list[int],
        start_indices: list[int],
        end_indices: list[int] | None = None,
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_indices = torch.tensor(start_indices)
        self.end_indices = (
            torch.tensor(end_indices)
            if end_indices is not None
            else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
        )

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape = [batch, seq_len]
        # logits.shape = [batch, vocab]
        current_index = input_ids.shape[1]
        mask = (self.start_indices <= current_index) & (
            current_index < self.end_indices
        )
        valid_batch_indices = torch.where(mask)[0]
        full_mask = torch.full_like(logits, -math.inf)
        full_mask[valid_batch_indices.unsqueeze(1), self.token_ids] = logits[
            valid_batch_indices.unsqueeze(1), self.token_ids
        ]
        # Rows inside their index range keep only the allowed tokens (every
        # other logit becomes -inf); rows outside the range are left unchanged.
        logits[valid_batch_indices] = full_mask[valid_batch_indices]
        return logits


class AllowOnlyTokensAtRelativeOffsetLogitsProcessor(LogitsProcessor):
    # If trigger_token_id was generated exactly `offset` positions back, only
    # the subsequent_token_ids are allowed at the current position.

    def __init__(
        self, trigger_token_id: int, subsequent_token_ids: list[int], offset: int
    ):
        self.trigger_token_id = trigger_token_id
        self.subsequent_token_ids = torch.tensor(subsequent_token_ids)
        self.offset = offset

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[batch, seq_len]
        # logits.shape=[batch, vocab]
        if input_ids.shape[1] < self.offset:
            return logits
        trigger_positions = (
            input_ids[:, -self.offset] == self.trigger_token_id
        ).unsqueeze(-1)
        disallowed_tokens_mask = torch.ones_like(logits, dtype=torch.bool)
        disallowed_tokens_mask[:, self.subsequent_token_ids] = False
        return logits.masked_fill_(
            disallowed_tokens_mask & trigger_positions,
            -math.inf,
        )
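
# Illustrative usage sketch (not part of the original module): e.g. forcing a
# hypothetical end-of-span token (id 1002) to be generated exactly 3 positions
# after a hypothetical start-of-span token (id 1001):
#
#   processor = AllowOnlyTokensAtRelativeOffsetLogitsProcessor(
#       trigger_token_id=1001, subsequent_token_ids=[1002], offset=3
#   )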


class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor):
    # If trigger_token_id appears anywhere in the last `width` tokens, only the
    # allowed_token_ids may be produced at the current position.

    def __init__(
        self, trigger_token_id: int, allowed_token_ids: list[int], width: int
    ):
        self.trigger_token_id = trigger_token_id
        self.allowed_token_ids = torch.tensor(allowed_token_ids).unsqueeze(
            0
        )  # shape: [1, num_allowed_tokens]
        self.width = width

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[batch, seq_len]
        # logits.shape=[batch, vocab]
        width = min(self.width, input_ids.shape[1])
        trigger_positions = (
            (input_ids[:, -width:] == self.trigger_token_id).any(dim=1).unsqueeze(-1)
        )
        disallowed_tokens_mask = torch.ones_like(logits, dtype=torch.bool)
        disallowed_tokens_mask[:, self.allowed_token_ids] = False
        return logits.masked_fill_(
            disallowed_tokens_mask & trigger_positions,
            -math.inf,
        )


class CFGLogitsProcessor(LogitsProcessor):
    # Classifier-free guidance: at every step the model is re-run on an
    # unconditional prompt (extended with each newly generated token), and the
    # conditional and unconditional logits are mixed with guidance_scale.

    def __init__(
        self,
        guidance_scale: float,
        unconditional_ids: torch.LongTensor,
        model,
    ):
        self.guidance_scale = guidance_scale
        self.unconditional_ids = unconditional_ids
        self.model = model

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        conditioned_logits = logits
        # Append the most recently generated token to the unconditional branch.
        self.unconditional_ids = torch.cat(
            [self.unconditional_ids, input_ids[:, -1:]], dim=1
        )
        # The model is expected to return logits of shape [batch, seq, vocab]
        # when called directly on token ids.
        unconditioned_outputs = self.model(self.unconditional_ids)
        unconditioned_logits = unconditioned_outputs[:, -1, :]
        return (
            self.guidance_scale * (conditioned_logits - unconditioned_logits)
            + unconditioned_logits
        )
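
# Illustrative note (not part of the original module): with guidance scale g
# the returned logits are  uncond + g * (cond - uncond),  so g == 1 recovers
# the conditional logits and g > 1 pushes sampling further toward the
# conditional branch. A hypothetical construction could look like:
#
#   processor = CFGLogitsProcessor(
#       guidance_scale=3.0,
#       unconditional_ids=uncond_prompt_ids,  # hypothetical [batch, prompt_len]
#       model=model,                          # callable returning [batch, seq, vocab]
#   )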


class InBatchCFGLogitsProcessor(LogitsProcessor):
    # Classifier-free guidance where the conditional and unconditional branches
    # run together in one batch: the first half of the rows are conditioned,
    # the second half unconditioned.

    def __init__(self, guidance_scale: float):
        self.guidance_scale = guidance_scale

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[2*batch, seq-len]
        # logits.shape=[2*batch, vocab]
        conditioned_logits, unconditioned_logits = torch.chunk(logits, chunks=2, dim=0)
        mixed_logits = unconditioned_logits + self.guidance_scale * (
            conditioned_logits - unconditioned_logits
        )
        # Duplicate the mixed logits so both halves of the batch sample from
        # the same distribution.
        return mixed_logits.repeat(2, 1)


class InBatchInstructCFGLogitsProcessor(LogitsProcessor):
    # Two-scale classifier-free guidance in the style of InstructPix2Pix,
    # see https://arxiv.org/abs/2211.09800. The batch is split into three
    # parts: fully conditioned, image-conditioned only, and unconditioned.

    def __init__(self, guidance_scale_text: float, guidance_scale_image: float):
        self.guidance_scale_text = guidance_scale_text
        self.guidance_scale_image = guidance_scale_image

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[3*batch, seq-len]
        # logits.shape=[3*batch, vocab]
        (
            full_conditioned_logits,
            image_conditioned_logits,
            unconditioned_logits,
        ) = logits.chunk(3)
        mixed_logits = (
            unconditioned_logits
            + self.guidance_scale_image
            * (image_conditioned_logits - unconditioned_logits)
            + self.guidance_scale_text
            * (full_conditioned_logits - image_conditioned_logits)
        )
        return mixed_logits.repeat(3, 1)
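

if __name__ == "__main__":
    # Minimal smoke-test sketch (added for illustration; not part of the
    # original module). It exercises two of the processors on toy tensors so
    # their shape and masking behaviour can be checked without a model.
    torch.manual_seed(0)

    # Top-p filtering on probabilities: the low-probability tail is zeroed and
    # the remaining mass renormalized.
    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    input_ids = torch.zeros(1, 4, dtype=torch.long)
    filtered = TopPProbabilityProcessor(top_p=0.9)(input_ids, probs)
    print("top-p filtered probs:", filtered)

    # In-batch CFG: the first half of the rows is conditioned, the second half
    # unconditioned; the mixed logits are duplicated across both halves.
    logits = torch.randn(2, 8)  # 2 rows (cond + uncond), toy vocab of 8
    mixed = InBatchCFGLogitsProcessor(guidance_scale=3.0)(
        torch.zeros(2, 4, dtype=torch.long), logits
    )
    print("in-batch CFG logits shape:", tuple(mixed.shape))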