# M3D-CLIP — configuration_m3d_clip.py
# Provenance: uploaded by GoodBaiBai88 ("Upload M3DCLIP", revision 8c45550, verified)
from transformers import PretrainedConfig
class M3DCLIPConfig(PretrainedConfig):
    """Configuration for the M3D-CLIP model.

    Bundles the hyperparameters for a 3D vision transformer (ViT-style
    patch encoder over volumetric images) paired with a text encoder
    loaded from ``language_model_name_or_path``, plus the CLIP-style
    contrastive-loss switches (``local_loss`` / ``gather_loss``).

    All constructor arguments are stored verbatim as attributes; any
    extra keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "m3d_clip"

    def __init__(
        self,
        language_model_name_or_path: str = 'bert-base-uncased',
        local_loss: bool = False,
        gather_loss: bool = True,
        in_channels: int = 1,
        img_size: tuple = (32, 256, 256),
        patch_size: tuple = (4, 16, 16),
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "perceptron",
        dropout_rate: float = 0,
        spatial_dims: int = 3,
        max_text_len: int = 128,
        vocab_size: int = 30522,
        **kwargs,
    ):
        # Record every explicit hyperparameter on the instance in one pass.
        own_attrs = dict(
            # Text encoder: backbone checkpoint, sequence length, vocab.
            language_model_name_or_path=language_model_name_or_path,
            max_text_len=max_text_len,
            vocab_size=vocab_size,
            # 3D vision transformer geometry and capacity.
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            pos_embed=pos_embed,
            dropout_rate=dropout_rate,
            spatial_dims=spatial_dims,
            # Contrastive-loss behavior flags.
            local_loss=local_loss,
            gather_loss=gather_loss,
        )
        for attr_name, attr_value in own_attrs.items():
            setattr(self, attr_name, attr_value)
        # Let the base class consume the remaining standard config kwargs last,
        # matching the original initialization order.
        super().__init__(**kwargs)