jadechoghari commited on
Commit
7c8fd9c
·
verified ·
1 Parent(s): 898456b

Create config_utils.py

Browse files
Files changed (1) hide show
  1. config_utils.py +40 -0
config_utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from omegaconf import OmegaConf, DictConfig
3
+ import os
4
+
5
+ def load_config(print_config = True):
6
+ parser = argparse.ArgumentParser()
7
+ parser.add_argument('--config', type=str,
8
+ default='configs/tea-pour.yaml',
9
+ help="Config file path")
10
+ args = parser.parse_args()
11
+ config = OmegaConf.load(args.config)
12
+
13
+ # Recursively merge base configs
14
+ cur_config_path = args.config
15
+ cur_config = config
16
+ while "base_config" in cur_config and cur_config.base_config != cur_config_path:
17
+ base_config = OmegaConf.load(cur_config.base_config)
18
+ config = OmegaConf.merge(base_config, config)
19
+ cur_config_path = cur_config.base_config
20
+ cur_config = base_config
21
+
22
+ prompt = config.generation.prompt
23
+ if isinstance(prompt, str):
24
+ prompt = {"edit": prompt}
25
+ config.generation.prompt = prompt
26
+ OmegaConf.resolve(config)
27
+ if print_config:
28
+ print("[INFO] loaded config:")
29
+ print(OmegaConf.to_yaml(config))
30
+
31
+ return config
32
+
33
+ def save_config(config: DictConfig, path, gene = False, inv = False):
34
+ os.makedirs(path, exist_ok = True)
35
+ config = OmegaConf.create(config)
36
+ if gene:
37
+ config.pop("inversion")
38
+ if inv:
39
+ config.pop("generation")
40
+ OmegaConf.save(config, os.path.join(path, "config.yaml"))