Spaces:
Runtime error
Runtime error
from torch import nn, Tensor | |
from typing import Union, Optional, Tuple | |
class BaseProjector(nn.Module):
    """Common parent for projector modules.

    It holds no parameters of its own. Subclasses are expected to override
    :meth:`forward`; calling ``forward`` on this base class raises
    ``NotImplementedError``.
    """

    def __init__(self):
        super(BaseProjector, self).__init__()

    def forward(self, x: Tensor) -> Tensor:
        # Placeholder: concrete projectors provide the actual mapping.
        raise NotImplementedError
class LinearProjector(BaseProjector):
    """Projector made of a single fully connected layer (with bias).

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Stored as ``fc`` so parameter names are ``fc.weight`` / ``fc.bias``.
        layer = nn.Linear(in_features=in_dim, out_features=out_dim)
        self.fc = layer

    def forward(self, x: Tensor) -> Tensor:
        projected = self.fc(x)
        return projected
class AdapterProjector(BaseProjector):
    """Two-layer bottleneck projector: Linear -> ReLU -> Linear -> ReLU.

    Neither linear layer has a bias. Because a ReLU follows the second layer,
    every output element is non-negative.

    Args:
        in_dim: size of the last dimension of the input.
        mid_dim: width of the hidden (bottleneck) layer.
        out_dim: size of the last dimension of the output.
    """

    def __init__(self, in_dim, mid_dim, out_dim):
        super().__init__()
        # Order matters: parameter names are fc.0.weight and fc.2.weight.
        # In-place ReLU is safe here because each one consumes a fresh
        # Linear output rather than a caller-owned tensor.
        stages = [
            nn.Linear(in_dim, mid_dim, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, out_dim, bias=False),
            nn.ReLU(inplace=True),
        ]
        self.fc = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        out = self.fc(x)
        return out
def create_projectors(dims):
    """Build a projector module from a sequence of layer widths.

    Args:
        dims: Layer widths, as any iterable of ints (or ``None``).
            ``[]``/``None`` -> ``nn.Identity()``;
            ``[in_dim, out_dim]`` -> ``LinearProjector``;
            ``[in_dim, mid_dim, out_dim]`` -> ``AdapterProjector``.

    Returns:
        An ``nn.Module`` implementing the requested projection.

    Raises:
        NotImplementedError: if ``dims`` has a length other than 0, 2 or 3.
    """
    # Materialize once so generators and other non-sized iterables work,
    # and treat a missing config (None) the same as "no projection".
    dims = tuple(dims) if dims is not None else ()
    if not dims:
        return nn.Identity()
    if len(dims) == 2:
        return LinearProjector(*dims)
    if len(dims) == 3:
        return AdapterProjector(*dims)
    # Same exception type as before (callers may catch it), but now with a
    # message that says what was passed.
    raise NotImplementedError(
        f"create_projectors supports 0, 2 or 3 dims, got {len(dims)}: {dims}"
    )