Spaces:
Runtime error
Runtime error
import copy | |
from transformers import BertConfig, CLIPVisionConfig | |
from transformers.configuration_utils import PretrainedConfig | |
from transformers.utils import logging | |
logger = logging.get_logger(__name__) | |
class CLIPVisionBertConfig(PretrainedConfig): | |
model_type = "clip-vision-bert" | |
is_composition = True | |
def __init__(self, **kwargs): | |
super().__init__(**kwargs) | |
if "bert_config" not in kwargs: | |
raise ValueError("`bert_config` can not be `None`.") | |
if "clip_vision_config" not in kwargs: | |
raise ValueError("`clip_vision_config` can not be `None`.") | |
bert_config = kwargs.pop("bert_config") | |
clip_vision_config = kwargs.pop("clip_vision_config") | |
self.bert_config = BertConfig(**bert_config) | |
self.clip_vision_config = CLIPVisionConfig(**clip_vision_config) | |
def from_clip_vision_bert_configs( | |
cls, | |
clip_vision_config: PretrainedConfig, | |
bert_config: PretrainedConfig, | |
**kwargs | |
): | |
return cls( | |
clip_vision_config=clip_vision_config.to_dict(), | |
bert_config=bert_config.to_dict(), | |
**kwargs | |
) | |
def to_dict(self): | |
output = copy.deepcopy(self.__dict__) | |
output["clip_vision_config"] = self.clip_vision_config.to_dict() | |
output["bert_config"] = self.bert_config.to_dict() | |
output["model_type"] = self.__class__.model_type | |
return output | |