Instantiated the nested PretrainedConfig correctly. Added a test to demo
Browse files- .gitignore +1 -0
- .vscode/settings.json +4 -0
- src/config.py +15 -0
- tests/test_config.py +13 -0
.gitignore
CHANGED
@@ -3,3 +3,4 @@
|
|
3 |
pyrightconfig.json
|
4 |
*.jpg
|
5 |
*.pyc
|
|
|
|
3 |
pyrightconfig.json
|
4 |
*.jpg
|
5 |
*.pyc
|
6 |
+
.env
|
.vscode/settings.json
CHANGED
@@ -20,4 +20,8 @@
|
|
20 |
// },
|
21 |
// },
|
22 |
// "isort.args":["--profile", "black"],
|
|
|
|
|
|
|
|
|
23 |
}
|
|
|
20 |
// },
|
21 |
// },
|
22 |
// "isort.args":["--profile", "black"],
|
23 |
+
"python.testing.unittestEnabled": false,
|
24 |
+
"python.testing.pytestEnabled": true,
|
25 |
+
"python.testing.cwd": "${workspaceFolder}/",
|
26 |
+
"python.envFile": "${workspaceFolder}/.env",
|
27 |
}
|
src/config.py
CHANGED
@@ -99,6 +99,16 @@ class TinyCLIPConfig(PretrainedConfig):
|
|
99 |
self.loss_type = loss_type
|
100 |
super().__init__(**kwargs)
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
class TrainerConfig(pydantic.BaseModel):
|
104 |
epochs: int = 20
|
@@ -119,3 +129,8 @@ class TrainerConfig(pydantic.BaseModel):
|
|
119 |
|
120 |
_model_config: TinyCLIPConfig = TinyCLIPConfig()
|
121 |
_data_config: DataConfig = DataConfig()
|
|
|
|
|
|
|
|
|
|
|
|
99 |
self.loss_type = loss_type
|
100 |
super().__init__(**kwargs)
|
101 |
|
102 |
+
@classmethod
|
103 |
+
def from_dict(cls, config_dict, **kwargs):
|
104 |
+
text_config_dict = config_dict.pop("text_config", {})
|
105 |
+
text_config = TinyCLIPTextConfig.from_dict(text_config_dict)
|
106 |
+
|
107 |
+
vision_config_dict = config_dict.pop("vision_config", {})
|
108 |
+
vision_config = TinyCLIPVisionConfig.from_dict(vision_config_dict)
|
109 |
+
|
110 |
+
return cls(text_config=text_config, vision_config=vision_config, **config_dict, **kwargs)
|
111 |
+
|
112 |
|
113 |
class TrainerConfig(pydantic.BaseModel):
|
114 |
epochs: int = 20
|
|
|
129 |
|
130 |
_model_config: TinyCLIPConfig = TinyCLIPConfig()
|
131 |
_data_config: DataConfig = DataConfig()
|
132 |
+
|
133 |
+
def __init__(self, **data):
|
134 |
+
super().__init__(**data)
|
135 |
+
if "_model_config" in data:
|
136 |
+
self._model_config = TinyCLIPConfig.from_dict(data["_model_config"])
|
tests/test_config.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src import config
|
2 |
+
import json
|
3 |
+
|
4 |
+
|
5 |
+
def test_trainer_config():
|
6 |
+
trainer_config = config.TrainerConfig.model_validate_json(
|
7 |
+
json.dumps({"epochs": 21, "_model_config": {"text_config": {"text_model": "test"}}})
|
8 |
+
)
|
9 |
+
|
10 |
+
assert trainer_config.epochs == 21
|
11 |
+
assert trainer_config._model_config.text_config.text_model == "test"
|
12 |
+
assert hasattr(trainer_config._model_config.text_config, "max_len")
|
13 |
+
assert trainer_config._model_config.vision_config == config.TinyCLIPVisionConfig()
|