Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import Any, Optional, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear.

    The input's last dimension is projected from ``embedding_dim`` up to
    ``mlp_dim``, passed through the activation, and projected back down to
    ``embedding_dim``, so the output's last dimension matches the input's.

    Args:
        embedding_dim: Size of the input and output feature dimension.
        mlp_dim: Size of the hidden feature dimension.
        act: Activation module class; it is instantiated with no arguments.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        # The submodule names lin1 / lin2 / act determine the parameter names in
        # state_dict; construction order (lin1 before lin2) also fixes the order
        # in which random initialisation consumes the RNG.
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection to ``x``."""
        hidden = self.lin1(x)
        hidden = self.act(hidden)
        return self.lin2(hidden)
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    """Layer normalisation across the channel dimension (dim 1) of a tensor.

    Every position is normalised over its channels using the biased
    (population) variance. The result is then scaled and shifted by a learnable
    per-channel ``weight`` (initialised to ones) and ``bias`` (initialised to
    zeros).

    Args:
        num_channels: Number of channels, i.e. the size of dim 1 of the input.
        eps: Added to the variance before the square root for numerical
            stability.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over dim 1 and apply the per-channel affine transform.

        The affine parameters are broadcast as ``(C, 1, 1)`` while statistics are
        taken over dim 1, so ``x`` is expected to be 4-D, laid out as
        ``(N, C, H, W)``.
        """
        centered = x - x.mean(dim=1, keepdim=True)
        variance = centered.pow(2).mean(dim=1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # Reshape the per-channel parameters to (C, 1, 1) so they broadcast over H, W.
        scale = self.weight.view(-1, 1, 1)
        shift = self.bias.view(-1, 1, 1)
        return scale * normalized + shift
def val2list(x: Any, repeat_time: int = 1) -> list:
    """Return ``x`` as a new list.

    Lists and tuples are shallow-copied into a fresh list. Any other value,
    including strings, is treated as a single item and repeated
    ``repeat_time`` times.

    Args:
        x: A list, a tuple, or a single value.
        repeat_time: Number of copies of ``x`` in the result when ``x`` is not a
            list or tuple; ignored otherwise.

    Returns:
        A new list (never the caller's object).
    """
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]
def val2tuple(x: Any, min_len: int = 1, idx_repeat: int = -1) -> tuple:
    """Return ``x`` as a tuple padded to at least ``min_len`` elements.

    ``x`` is first converted with ``val2list`` (so a scalar becomes a
    one-element list). If the result is shorter than ``min_len``, copies of the
    element at index ``idx_repeat`` are inserted immediately before that element
    until the length reaches ``min_len``. With the default ``idx_repeat=-1``
    this repeats the last element. Empty input stays empty.

    Examples:
        ``val2tuple(3, 2) == (3, 3)``
        ``val2tuple([1, 2], 4) == (1, 2, 2, 2)``
        ``val2tuple([1, 2], 4, idx_repeat=0) == (1, 1, 1, 2)``

    Args:
        x: A list, a tuple, or a single value.
        min_len: Minimum length of the result; longer inputs are not truncated.
        idx_repeat: Index of the element to duplicate when padding.

    Returns:
        A tuple.
    """
    x = val2list(x)
    # Pad by slice-assigning at idx_repeat; nothing to repeat if x is empty, and
    # range() of a non-positive count inserts nothing when x is already long enough.
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
    return tuple(x)
def list_sum(x: list) -> Any:
    """Combine the elements of a non-empty sequence with ``+``.

    The result is evaluated as ``x[0] + (x[1] + (... + x[-1]))`` (right-nested),
    so it works for any element type that supports ``+`` (numbers, tensors,
    lists, strings). A single-element input returns that element itself.

    Args:
        x: Non-empty list (or other indexable, sliceable sequence).

    Returns:
        The combined value.

    Raises:
        IndexError: If ``x`` is empty.
    """
    if len(x) == 0:
        raise IndexError("list_sum() requires a non-empty sequence")
    # Fold from the right iteratively instead of recursing on x[1:]: recursion
    # copies the tail at every level (O(n^2)) and hits RecursionError for long
    # inputs. Folding right-to-left keeps the exact original evaluation order.
    result = x[-1]
    for item in reversed(x[:-1]):
        result = item + result
    return result
def resize(
    x: torch.Tensor,
    size: Any = None,
    scale_factor=None,
    mode: str = "bicubic",
    align_corners: Optional[bool] = False,
) -> torch.Tensor:
    """Resize ``x`` with ``torch.nn.functional.interpolate``.

    Args:
        x: Tensor to resize, in the layout ``F.interpolate`` expects for ``mode``.
        size: Target output size (int or sequence of ints), or ``None`` to use
            ``scale_factor`` instead.
        scale_factor: Multiplier for the spatial size; used when ``size`` is
            ``None``.
        mode: Interpolation mode. ``"linear"``, ``"bilinear"``, ``"bicubic"`` and
            ``"trilinear"`` receive ``align_corners``; ``"nearest"`` and
            ``"area"`` are called without it.
        align_corners: Forwarded to ``F.interpolate`` for the interpolating modes
            only.

    Returns:
        The resized tensor.

    Raises:
        NotImplementedError: For any other ``mode``.
    """
    if mode in ("linear", "bilinear", "bicubic", "trilinear"):
        return F.interpolate(
            x,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
        )
    if mode in ("nearest", "area"):
        # F.interpolate only accepts align_corners for the interpolating modes,
        # so it is deliberately not forwarded here.
        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
    raise NotImplementedError(f"resize(mode={mode}) not implemented.")
class UpSampleLayer(nn.Module):
    """Module wrapper that resizes its input with ``resize``.

    When ``size`` is given it is normalised with ``val2list(size, 2)`` (a scalar
    becomes ``[size, size]``, a list/tuple is copied into a list) and ``factor``
    is ignored. Otherwise the input is scaled by ``factor``.

    Args:
        mode: Interpolation mode forwarded to ``resize``.
        size: Target output size, or ``None`` to scale by ``factor`` instead.
        factor: Scale factor; used only when ``size`` is ``None``.
        align_corners: Forwarded to ``resize``.
    """

    def __init__(
        self,
        mode="bicubic",
        size=None,
        factor=2,
        align_corners=False,
    ):
        super().__init__()
        self.mode = mode
        if size is None:
            self.size = None
            self.factor = factor
        else:
            # An explicit size takes precedence; drop the scale factor.
            self.size = val2list(size, 2)
            self.factor = None
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resize ``x`` using the stored size/factor, mode and align_corners."""
        return resize(
            x,
            size=self.size,
            scale_factor=self.factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
class OpSequential(nn.Module):
    """Apply a sequence of modules in order, skipping ``None`` entries.

    ``None`` items in ``op_list`` are discarded at construction time; the
    remaining modules are registered in an ``nn.ModuleList`` named ``op_list``.
    With no remaining modules, ``forward`` returns its input unchanged.

    Args:
        op_list: Iterable of modules, possibly containing ``None``.
    """

    def __init__(self, op_list):
        super().__init__()
        self.op_list = nn.ModuleList([op for op in op_list if op is not None])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed ``x`` through each retained module in order."""
        out = x
        for layer in self.op_list:
            out = layer(out)
        return out