import math

import einops
import numpy as np
import torch
import torch.nn as nn


class Normalize(nn.Module):
    """L2-normalizes the input along the given dimension."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)


class LearnableLogitScaling(nn.Module):
    """Scales its input by an exponentiated temperature, stored in log space,
    optionally learnable and clipped at ``max_logit_scale``."""

    def __init__(
        self,
        logit_scale_init: float = 1 / 0.07,
        learnable: bool = True,
        max_logit_scale: float = 100,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.logit_scale_init = logit_scale_init
        self.learnable = learnable
        log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
        if learnable:
            self.log_logit_scale = nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x):
        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

    def extra_repr(self):
        st = (
            f"logit_scale_init={self.logit_scale_init}, learnable={self.learnable}, "
            f"max_logit_scale={self.max_logit_scale}"
        )
        return st


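# Illustrative sketch (not part of the original code): Normalize and
# LearnableLogitScaling are typically chained so that features are
# unit-normalized and then scaled by a learnable temperature before
# computing contrastive logits.
#
#   head = nn.Sequential(
#       Normalize(dim=-1),
#       LearnableLogitScaling(logit_scale_init=1 / 0.07, learnable=True),
#   )
#   scaled_feats = head(torch.randn(8, 512))  # same shape, scaled rows

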
class EinOpsRearrange(nn.Module):
    """Applies an einops rearrange expression, so it can be used inside nn.Sequential."""

    def __init__(self, rearrange_expr: str, **kwargs) -> None:
        super().__init__()
        self.rearrange_expr = rearrange_expr
        self.kwargs = kwargs

    def forward(self, x):
        assert isinstance(x, torch.Tensor)
        return einops.rearrange(x, self.rearrange_expr, **self.kwargs)


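# Illustrative sketch (not part of the original code): EinOpsRearrange lets an
# einops pattern be dropped into a module pipeline, e.g. to move the batch
# dimension in front of the sequence dimension.
#
#   to_batch_first = EinOpsRearrange("s b d -> b s d")
#   y = to_batch_first(torch.randn(10, 4, 16))  # -> shape (4, 10, 16)

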
class VerboseNNModule(nn.Module):
    """
    Wrapper around nn.Module that prints registered buffers and parameter names.
    """

    @staticmethod
    def get_readable_tensor_repr(name: str, tensor) -> str:
        # ``tensor`` is a (name, tensor) pair as yielded by named_parameters()
        # or named_buffers(); only the tensor element is inspected.
        st = (
            "("
            + name
            + "): "
            + "tensor("
            + str(tuple(tensor[1].shape))
            + ", requires_grad="
            + str(tensor[1].requires_grad)
            + ")\n"
        )
        return st

    def extra_repr(self) -> str:
        named_modules = set()
        for p in self.named_modules():
            named_modules.update([p[0]])
        named_modules = list(named_modules)

        string_repr = ""
        # Show top-level parameters that do not belong to a child module ...
        for p in self.named_parameters():
            name = p[0].split(".")[0]
            if name not in named_modules:
                string_repr += self.get_readable_tensor_repr(name, p)

        # ... and all registered buffers.
        for p in self.named_buffers():
            name = p[0].split(".")[0]
            string_repr += self.get_readable_tensor_repr(name, p)

        return string_repr


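# Illustrative sketch (not part of the original code): subclassing
# VerboseNNModule makes print(module) also list top-level parameters and
# registered buffers.
#
#   class Head(VerboseNNModule):
#       def __init__(self):
#           super().__init__()
#           self.weight = nn.Parameter(torch.zeros(4, 4))
#           self.register_buffer("scale", torch.ones(1))
#
#   print(Head())  # includes "(weight): tensor((4, 4), requires_grad=True)"

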
def cast_if_src_dtype(
    tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
):
    """Casts ``tensor`` to ``tgt_dtype`` if it currently has ``src_dtype``.

    Returns the (possibly cast) tensor and a flag indicating whether a cast happened.
    """
    updated = False
    if tensor.dtype == src_dtype:
        tensor = tensor.to(dtype=tgt_dtype)
        updated = True
    return tensor, updated


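# Illustrative sketch (not part of the original code): cast_if_src_dtype is
# handy around ops that are unstable in reduced precision; cast up, run the
# op, then cast back only if a cast actually happened.
#
#   x = torch.randn(4, 4, dtype=torch.bfloat16)
#   x, was_cast = cast_if_src_dtype(x, torch.bfloat16, torch.float32)
#   x = torch.softmax(x, dim=-1)
#   if was_cast:
#       x, _ = cast_if_src_dtype(x, torch.float32, torch.bfloat16)

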
class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU, x * sigmoid(1.702 * x), as used in CLIP."""

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


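# Illustrative sketch (not part of the original code): QuickGELU tracks the
# exact GELU closely over typical activation ranges.
#
#   x = torch.linspace(-3, 3, steps=7)
#   approx = QuickGELU()(x)
#   exact = nn.GELU()(x)  # max abs difference is on the order of 1e-2

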
class SelectElement(nn.Module):
    """Selects a single position along the sequence dimension (e.g. a CLS token)."""

    def __init__(self, index) -> None:
        super().__init__()
        self.index = index

    def forward(self, x):
        assert x.ndim >= 3
        return x[:, self.index, ...]


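# Illustrative sketch (not part of the original code):
#
#   tokens = torch.randn(2, 5, 16)        # (batch, seq, dim)
#   cls = SelectElement(index=0)(tokens)  # -> shape (2, 16)

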
class SelectEOSAndProject(nn.Module):
    """
    Text pooling used in OpenCLIP: select the feature at each sequence's EOS
    position and apply a projection.
    """

    def __init__(self, proj: nn.Module) -> None:
        super().__init__()
        self.proj = proj

    def forward(self, x, seq_len):
        assert x.ndim == 3
        # Pick, for every example in the batch, the token at its EOS index.
        x = x[torch.arange(x.shape[0]), seq_len]
        x = self.proj(x)
        return x


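if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the original module):
    # exercise the helpers above on small random inputs and print shapes.
    feats = torch.randn(2, 7, 16)  # (batch, seq, dim)

    pooled = SelectElement(index=0)(feats)
    print("SelectElement:", tuple(pooled.shape))  # (2, 16)

    head = nn.Sequential(Normalize(dim=-1), LearnableLogitScaling())
    print("Normalize + LearnableLogitScaling:", tuple(head(pooled).shape))  # (2, 16)

    # The projection and per-example EOS positions below are made up for the test.
    eos_pool = SelectEOSAndProject(proj=nn.Linear(16, 8))
    seq_len = torch.tensor([3, 6])
    print("SelectEOSAndProject:", tuple(eos_pool(feats, seq_len).shape))  # (2, 8)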