sachin commited on
Commit
571c526
1 Parent(s): 18cb46c

Instantiated the nested PretrainedConfig correctly. Added a test to demo

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. .vscode/settings.json +4 -0
  3. src/config.py +15 -0
  4. tests/test_config.py +13 -0
.gitignore CHANGED
@@ -3,3 +3,4 @@
3
  pyrightconfig.json
4
  *.jpg
5
  *.pyc
 
 
3
  pyrightconfig.json
4
  *.jpg
5
  *.pyc
6
+ .env
.vscode/settings.json CHANGED
@@ -20,4 +20,8 @@
20
  // },
21
  // },
22
  // "isort.args":["--profile", "black"],
 
 
 
 
23
  }
 
20
  // },
21
  // },
22
  // "isort.args":["--profile", "black"],
23
+ "python.testing.unittestEnabled": false,
24
+ "python.testing.pytestEnabled": true,
25
+ "python.testing.cwd": "${workspaceFolder}/",
26
+ "python.envFile": "${workspaceFolder}/.env",
27
  }
src/config.py CHANGED
@@ -99,6 +99,16 @@ class TinyCLIPConfig(PretrainedConfig):
99
  self.loss_type = loss_type
100
  super().__init__(**kwargs)
101
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  class TrainerConfig(pydantic.BaseModel):
104
  epochs: int = 20
@@ -119,3 +129,8 @@ class TrainerConfig(pydantic.BaseModel):
119
 
120
  _model_config: TinyCLIPConfig = TinyCLIPConfig()
121
  _data_config: DataConfig = DataConfig()
 
 
 
 
 
 
99
  self.loss_type = loss_type
100
  super().__init__(**kwargs)
101
 
102
+ @classmethod
103
+ def from_dict(cls, config_dict, **kwargs):
104
+ text_config_dict = config_dict.pop("text_config", {})
105
+ text_config = TinyCLIPTextConfig.from_dict(text_config_dict)
106
+
107
+ vision_config_dict = config_dict.pop("vision_config", {})
108
+ vision_config = TinyCLIPVisionConfig.from_dict(vision_config_dict)
109
+
110
+ return cls(text_config=text_config, vision_config=vision_config, **config_dict, **kwargs)
111
+
112
 
113
  class TrainerConfig(pydantic.BaseModel):
114
  epochs: int = 20
 
129
 
130
  _model_config: TinyCLIPConfig = TinyCLIPConfig()
131
  _data_config: DataConfig = DataConfig()
132
+
133
+ def __init__(self, **data):
134
+ super().__init__(**data)
135
+ if "_model_config" in data:
136
+ self._model_config = TinyCLIPConfig.from_dict(data["_model_config"])
tests/test_config.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src import config
2
+ import json
3
+
4
+
5
+ def test_trainer_config():
6
+ trainer_config = config.TrainerConfig.model_validate_json(
7
+ json.dumps({"epochs": 21, "_model_config": {"text_config": {"text_model": "test"}}})
8
+ )
9
+
10
+ assert trainer_config.epochs == 21
11
+ assert trainer_config._model_config.text_config.text_model == "test"
12
+ assert hasattr(trainer_config._model_config.text_config, "max_len")
13
+ assert trainer_config._model_config.vision_config == config.TinyCLIPVisionConfig()