Handle dictionary in confgw
Browse files- configuration_aimv2.py +8 -3
configuration_aimv2.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import Any, Dict, Optional
|
2 |
|
3 |
from transformers.configuration_utils import PretrainedConfig
|
4 |
|
@@ -148,8 +148,8 @@ class AIMv2Config(PretrainedConfig):
|
|
148 |
|
149 |
def __init__(
|
150 |
self,
|
151 |
-
vision_config: Optional[AIMv2VisionConfig] = None,
|
152 |
-
text_config: Optional[AIMv2TextConfig] = None,
|
153 |
projection_dim: int = 768,
|
154 |
init_temperature: float = 0.07,
|
155 |
max_logit_scale: float = 100.0,
|
@@ -158,8 +158,13 @@ class AIMv2Config(PretrainedConfig):
|
|
158 |
super().__init__(**kwargs)
|
159 |
if vision_config is None:
|
160 |
vision_config = AIMv2VisionConfig()
|
|
|
|
|
|
|
161 |
if text_config is None:
|
162 |
text_config = AIMv2TextConfig()
|
|
|
|
|
163 |
|
164 |
self.vision_config = vision_config
|
165 |
self.text_config = text_config
|
|
|
1 |
+
from typing import Any, Dict, Optional, Union
|
2 |
|
3 |
from transformers.configuration_utils import PretrainedConfig
|
4 |
|
|
|
148 |
|
149 |
def __init__(
|
150 |
self,
|
151 |
+
vision_config: Optional[Union[AIMv2VisionConfig, Dict[str, Any]]] = None,
|
152 |
+
text_config: Optional[Union[AIMv2TextConfig, Dict[str, Any]]] = None,
|
153 |
projection_dim: int = 768,
|
154 |
init_temperature: float = 0.07,
|
155 |
max_logit_scale: float = 100.0,
|
|
|
158 |
super().__init__(**kwargs)
|
159 |
if vision_config is None:
|
160 |
vision_config = AIMv2VisionConfig()
|
161 |
+
elif isinstance(vision_config, dict):
|
162 |
+
vision_config = AIMv2VisionConfig(**vision_config)
|
163 |
+
|
164 |
if text_config is None:
|
165 |
text_config = AIMv2TextConfig()
|
166 |
+
elif isinstance(text_config, dict):
|
167 |
+
text_config = AIMv2TextConfig(**text_config)
|
168 |
|
169 |
self.vision_config = vision_config
|
170 |
self.text_config = text_config
|