Spaces:
Runtime error
Runtime error
File size: 1,005 Bytes
a65550c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
from .masked_drop import MaskedDrop
from .spatial_pool import SpatialPool
from .perceiver import PerceiverResampler
from .qformer import Qformer
class IdentityMap(torch.nn.Module):
    """Pass-through module that returns its input unchanged.

    Returned by ``build_vision_resampler`` when no resampler type is
    configured.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x`` as-is; extra positional/keyword arguments are ignored."""
        return x

    @property
    def config(self):
        """Return the config dict for this module: the resampler type is ``None``."""
        return {"mm_resampler_type": None}
def build_vision_resampler(model_args, delay_load=False, **kwargs):
    """Construct the vision resampler selected by ``model_args``.

    The type is read from ``model_args.mm_resampler_type``; a missing
    attribute is treated the same as ``None``.

    Args:
        model_args: Object carrying the optional ``mm_resampler_type``
            attribute; it is also passed to the chosen resampler's
            constructor.
        delay_load: Accepted for call-site compatibility; not used by this
            function.
        **kwargs: Forwarded to the ``SpatialPool``, ``PerceiverResampler``
            and ``Qformer`` constructors.

    Returns:
        The resampler instance, or an ``IdentityMap`` when the type is
        ``None``.

    Raises:
        ValueError: If the resampler type is not one of the known values.
    """
    resampler_type = getattr(model_args, "mm_resampler_type", None)

    # MaskedDrop is constructed from model_args alone; **kwargs is not forwarded.
    if resampler_type == "masked_drop":
        return MaskedDrop(model_args)
    if resampler_type == "spatial_pool":
        return SpatialPool(model_args, **kwargs)
    if resampler_type == "perceiver":
        return PerceiverResampler(model_args, **kwargs)
    if resampler_type == "qformer":
        return Qformer(model_args, **kwargs)
    if resampler_type is None:
        return IdentityMap()

    raise ValueError(f"Unknown resampler type: {resampler_type}")
|