michalk8 commited on
Commit
fb57492
·
1 Parent(s): 5fd59c5

Handle dictionary in confgw

Browse files
Files changed (1) hide show
  1. configuration_aimv2.py +8 -3
configuration_aimv2.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Optional
2
 
3
  from transformers.configuration_utils import PretrainedConfig
4
 
@@ -148,8 +148,8 @@ class AIMv2Config(PretrainedConfig):
148
 
149
  def __init__(
150
  self,
151
- vision_config: Optional[AIMv2VisionConfig] = None,
152
- text_config: Optional[AIMv2TextConfig] = None,
153
  projection_dim: int = 768,
154
  init_temperature: float = 0.07,
155
  max_logit_scale: float = 100.0,
@@ -158,8 +158,13 @@ class AIMv2Config(PretrainedConfig):
158
  super().__init__(**kwargs)
159
  if vision_config is None:
160
  vision_config = AIMv2VisionConfig()
 
 
 
161
  if text_config is None:
162
  text_config = AIMv2TextConfig()
 
 
163
 
164
  self.vision_config = vision_config
165
  self.text_config = text_config
 
1
+ from typing import Any, Dict, Optional, Union
2
 
3
  from transformers.configuration_utils import PretrainedConfig
4
 
 
148
 
149
  def __init__(
150
  self,
151
+ vision_config: Optional[Union[AIMv2VisionConfig, Dict[str, Any]]] = None,
152
+ text_config: Optional[Union[AIMv2TextConfig, Dict[str, Any]]] = None,
153
  projection_dim: int = 768,
154
  init_temperature: float = 0.07,
155
  max_logit_scale: float = 100.0,
 
158
  super().__init__(**kwargs)
159
  if vision_config is None:
160
  vision_config = AIMv2VisionConfig()
161
+ elif isinstance(vision_config, dict):
162
+ vision_config = AIMv2VisionConfig(**vision_config)
163
+
164
  if text_config is None:
165
  text_config = AIMv2TextConfig()
166
+ elif isinstance(text_config, dict):
167
+ text_config = AIMv2TextConfig(**text_config)
168
 
169
  self.vision_config = vision_config
170
  self.text_config = text_config