|
import copy
import json
|
|
|
# Baseline configuration. ``load_config`` fills any key that is absent from a
# user's JSON config file with the value found here.
DEFAULTS = {
    # Model selection.
    "network": "dpn",
    "encoder": "dpn92",
    "model_params": {},
    # Optimizer settings.
    "optimizer": {
        "batch_size": 32,
        "type": "SGD",
        "momentum": 0.9,
        "weight_decay": 0,
        "clip": 1.0,
        "learning_rate": 0.1,
        # NOTE(review): -1 looks like a "not set" sentinel; confirm how the
        # training code interprets it.
        "classifier_lr": -1,
        "nesterov": True,
        # Learning-rate schedule.
        "schedule": {
            "type": "constant",
            "mode": "epoch",
            "epochs": 10,
            "params": {},
        },
    },
    # Per-channel input normalization; these values match the commonly used
    # ImageNet RGB mean/std statistics.
    "normalize": {
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
    },
}
|
|
|
|
|
def _merge(src, dst): |
|
for k, v in src.items(): |
|
if k in dst: |
|
if isinstance(v, dict): |
|
_merge(src[k], dst[k]) |
|
else: |
|
dst[k] = v |
|
|
|
|
|
def load_config(config_file, defaults=DEFAULTS):
    """Load a JSON config file and complete it with default values.

    Keys that the file does not define are filled in from ``defaults``;
    values that the file does define are kept as they are.

    Args:
        config_file: Path of the JSON file to read.
        defaults: Mapping of fallback values. Defaults to the module-level
            ``DEFAULTS``.

    Returns:
        The parsed config, with missing keys filled in from ``defaults``.

    Raises:
        OSError: If ``config_file`` cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON interchange text is UTF-8; be explicit so that decoding does not
    # depend on the platform's locale default encoding.
    with open(config_file, "r", encoding="utf-8") as fd:
        config = json.load(fd)
    _merge(defaults, config)
    return config
|
|