import torch
from torch import nn as nn
from torch.nn import functional as F


RESOLUTION_PATCH_NUMS_MAPPING = {
    256: "1_2_3_4_5_6_8_10_13_16",
    512: "1_2_3_4_6_9_13_18_24_32",
    1024: "1_2_3_4_5_7_9_12_16_21_27_36_48_64",
}
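
# Minimal usage sketch (an assumption about how this table is consumed, not code from
# this file): downstream code would parse the underscore-separated string into the
# per-scale patch numbers for the chosen resolution.
# >>> patch_nums = tuple(int(p) for p in RESOLUTION_PATCH_NUMS_MAPPING[256].split("_"))
# >>> patch_nums
# (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)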
|
|
def sample_with_top_k_top_p_(
    logits_BlV: torch.Tensor,
    top_k: int = 0,
    top_p: float = 0.0,
    rng=None,
    num_samples=1,
) -> torch.Tensor:
    """Sample token indices from (B, l, V) logits with top-k and/or top-p (nucleus) filtering.

    The trailing underscore signals that ``logits_BlV`` is modified in place.
    """
    B, l, V = logits_BlV.shape
    if top_k > 0:
        # Mask out every logit smaller than the k-th largest logit along the vocab dim.
        idx_to_remove = logits_BlV < logits_BlV.topk(
            top_k, largest=True, sorted=False, dim=-1
        )[0].amin(dim=-1, keepdim=True)
        logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
    if top_p > 0:
        # Sort ascending and drop the low-probability tail whose cumulative mass is <= 1 - top_p.
        sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
        sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
        sorted_idx_to_remove[..., -1:] = False  # always keep the most likely token
        logits_BlV.masked_fill_(
            sorted_idx_to_remove.scatter(
                sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove
            ),
            -torch.inf,
        )

    # A negative num_samples requests sampling without replacement.
    replacement = num_samples >= 0
    num_samples = abs(num_samples)
    return torch.multinomial(
        logits_BlV.softmax(dim=-1).view(-1, V),
        num_samples=num_samples,
        replacement=replacement,
        generator=rng,
    ).view(B, l, num_samples)
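
# Minimal usage sketch (shapes and hyperparameters below are illustrative only):
# >>> g = torch.Generator().manual_seed(0)
# >>> logits = torch.randn(2, 5, 4096)                  # (B, l, V)
# >>> idx = sample_with_top_k_top_p_(logits, top_k=600, top_p=0.96, rng=g, num_samples=1)
# >>> idx.shape
# torch.Size([2, 5, 1])
# Note: the function filters ``logits`` in place, so pass a clone if the tensor is reused.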
|
|
def gumbel_softmax_with_rng(
    logits: torch.Tensor,
    tau: float = 1,
    hard: bool = False,
    eps: float = 1e-10,
    dim: int = -1,
    rng: torch.Generator | None = None,
) -> torch.Tensor:
    """F.gumbel_softmax, extended to accept an explicit torch.Generator for reproducible noise."""
    if rng is None:
        return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)

    # Gumbel(0, 1) noise: -log(Exponential(1)), drawn with the provided generator.
    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
        .exponential_(generator=rng)
        .log()
    )
    gumbels = (logits + gumbels) / tau
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight-through estimator: forward pass uses the one-hot argmax,
        # backward pass uses the gradient of the soft sample.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(
            logits, memory_format=torch.legacy_contiguous_format
        ).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        ret = y_soft
    return ret
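
# Minimal usage sketch (values are illustrative): with an explicit generator the Gumbel
# noise, and therefore the sample, is reproducible.
# >>> g = torch.Generator().manual_seed(0)
# >>> logits = torch.randn(3, 10)
# >>> one_hot = gumbel_softmax_with_rng(logits, tau=1.0, hard=True, rng=g)
# >>> one_hot.sum(dim=-1)                                # each row is one-hot
# tensor([1., 1., 1.])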
|
|
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop entire residual-branch outputs per sample (stochastic depth); identity at eval time."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)  # rescale so the expected activation is unchanged
    return x * random_tensor
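
# Minimal sketch of the typical call site (``sublayer`` is a hypothetical module, not
# defined in this file): drop_path is applied to a residual branch, so with probability
# ``drop_prob`` a sample's branch output is zeroed and only the skip connection remains.
# >>> x = torch.randn(4, 16, 64)
# >>> out = x + drop_path(sublayer(x), drop_prob=0.1, training=True)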
|
|
class DropPath(nn.Module):
    """Module wrapper around drop_path (per-sample stochastic depth)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        # Report the configured probability instead of a constant placeholder string.
        return f"drop_prob={round(self.drop_prob, 3):0.3f}"
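
# Minimal usage sketch (module dimensions below are illustrative):
# >>> block_dp = DropPath(drop_prob=0.1)
# >>> block_dp.train()
# >>> y = block_dp(torch.randn(8, 32, 256))              # per-sample stochastic depth
# >>> block_dp.eval()
# >>> torch.equal(block_dp(torch.ones(2, 3)), torch.ones(2, 3))   # identity at eval time
# True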
|