from transformers import PretrainedConfig


class M3DCLIPConfig(PretrainedConfig):
    """Configuration for M3D-CLIP: a 3D vision transformer image encoder
    paired with a pretrained language model (BERT by default) for
    CLIP-style image-text contrastive learning."""

    model_type = "m3d_clip"

    def __init__(
        self,
        # Text encoder
        language_model_name_or_path: str = 'bert-base-uncased',
        # Distributed contrastive-loss options
        local_loss: bool = False,
        gather_loss: bool = True,
        # 3D vision transformer
        in_channels: int = 1,           # input channels of the volume
        img_size: tuple = (32, 256, 256),   # input volume size (D, H, W)
        patch_size: tuple = (4, 16, 16),    # 3D patch size (D, H, W)
        hidden_size: int = 768,         # transformer embedding dimension
        mlp_dim: int = 3072,            # feed-forward hidden dimension
        num_layers: int = 12,           # number of transformer blocks
        num_heads: int = 12,            # attention heads per block
        pos_embed: str = "perceptron",  # positional-embedding type
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,          # operate on 3D inputs
        # Text tokenization
        max_text_len: int = 128,        # maximum token length for text inputs
        vocab_size: int = 30522,        # matches bert-base-uncased
        **kwargs,
    ):
        self.language_model_name_or_path = language_model_name_or_path
        self.in_channels = in_channels
        self.img_size = img_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.mlp_dim = mlp_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.pos_embed = pos_embed
        self.dropout_rate = dropout_rate
        self.spatial_dims = spatial_dims
        self.local_loss = local_loss
        self.gather_loss = gather_loss
        self.max_text_len = max_text_len
        self.vocab_size = vocab_size
        super().__init__(**kwargs)
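

# A minimal usage sketch (illustrative, not part of the original module):
# instantiate the config, persist it, and reload it via the standard
# PretrainedConfig round-trip. The output directory below is a placeholder.
# Note that tuple fields such as img_size come back as lists after the
# JSON round-trip.
if __name__ == "__main__":
    config = M3DCLIPConfig(dropout_rate=0.1)        # override any default
    config.save_pretrained("./m3d_clip_config")     # writes config.json
    reloaded = M3DCLIPConfig.from_pretrained("./m3d_clip_config")
    print(reloaded.model_type)  # "m3d_clip"
    print(reloaded.img_size)    # [32, 256, 256] (tuple serialized as list)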