File size: 1,613 Bytes
38c3084 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import torch
from torch import nn
class ClsToken(nn.Module):
    """Learnable tokens that are prepended to a sequence along dim 1.

    When enabled, the module owns a single parameter ``token`` of shape
    ``(num_tokens + num_registers, ndim)``, initialised from a standard
    normal scaled by ``ndim ** -0.5``. ``forward`` broadcasts it over the
    batch dimension of the input and concatenates it in front of the input.

    Attributes set by the constructor:
        ndim: Size of the last dimension of ``token``.
        enabled: Whether the token parameter was created / is still active.
        num_tokens: The ``num_tokens`` argument, stored as given.
        num_registers: Number of extra rows appended after ``num_tokens``
            (0 when disabled or when ``register_multiple <= 0``).
        num_patches: ``num_tokens + num_registers``.
        token: The ``nn.Parameter``, or ``None`` when disabled.
    """

    def __init__(self, ndim: int,
                 num_tokens: int = 1,
                 enabled: bool = True,
                 register_multiple: int = 0,
    ):
        """Create the token parameter (if enabled) and record the counts.

        Args:
            ndim: Width of each token row.
            num_tokens: Number of leading token rows.
            enabled: When false, no parameter is created and ``forward``
                returns its input unchanged.
            register_multiple: When positive (and enabled), extra rows are
                added so the total row count is a multiple of this value.
        """
        super().__init__()

        self.ndim = ndim
        self.enabled = enabled
        self.num_tokens = num_tokens

        # Extra rows needed to reach the next multiple of `register_multiple`.
        # Note that when `num_tokens` is already an exact multiple this adds
        # a full additional group of `register_multiple` rows, not zero.
        extra_rows = 0
        if enabled and register_multiple > 0:
            extra_rows = register_multiple - (num_tokens % register_multiple)
        self.num_registers = extra_rows

        if not enabled:
            self.token = None
        else:
            init_scale = ndim ** -0.5
            self.token = nn.Parameter(
                torch.randn(num_tokens + extra_rows, ndim) * init_scale
            )

        # Computed from the stored counts regardless of `enabled`.
        self.num_patches = self.num_tokens + self.num_registers

    def disable(self):
        """Drop the token parameter and mark the module as disabled.

        Only ``token`` and ``enabled`` are changed; ``num_tokens``,
        ``num_registers`` and ``num_patches`` keep their previous values.
        """
        self.enabled = False
        self.token = None

    def forward(self, x: torch.Tensor):
        """Return ``x`` with the token rows prepended along dim 1.

        Args:
            x: Input tensor. Because the tokens are expanded to
                ``(x.shape[0], rows, ndim)`` and concatenated on dim 1,
                ``x`` must be 3-D with a last dimension equal to ``ndim``.

        Returns:
            ``x`` itself when there is no token parameter; otherwise a new
            tensor of shape ``(x.shape[0], rows + x.shape[1], ndim)``.
        """
        if self.token is None:
            return x

        batch_size = x.shape[0]
        # Broadcast (rows, ndim) -> (batch, rows, ndim) without copying.
        prefix = self.token[None].expand(batch_size, -1, -1)
        return torch.cat((prefix, x), dim=1)

    def no_weight_decay(self):
        """Return the parameter names of this module to exempt from weight decay.

        NOTE(review): presumably consumed by optimizer setup code elsewhere;
        the consumer is not visible in this file.
        """
        return ['token']
|