zetavg committed
Commit 9bd8d8b • 1 Parent(s): 40a8f4e

support .yaml config

Files changed (3):
  1. .gitignore +1 -0
  2. app.py +24 -1
  3. config.yaml.sample +22 -0
.gitignore CHANGED
@@ -3,5 +3,6 @@ __pycache__/
 /venv
 .vscode
 
+/config.yaml
 /wandb
 /data
app.py CHANGED
@@ -1,7 +1,9 @@
 from typing import Union
 
-import fire
 import gradio as gr
+import fire
+import os
+import yaml
 
 from llama_lora.config import Config, process_config
 from llama_lora.globals import initialize_global
@@ -41,6 +43,14 @@ def main(
     :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
     '''
 
+    config_from_file = read_yaml_config()
+    if config_from_file:
+        for key, value in config_from_file.items():
+            if not hasattr(Config, key):
+                available_keys = [k for k in vars(Config) if not k.startswith('__')]
+                raise ValueError(f"Invalid config key '{key}' in config.yaml. Available keys: {', '.join(available_keys)}")
+            setattr(Config, key, value)
+
     if base_model is not None:
         Config.default_base_model_name = base_model
 
@@ -91,5 +101,18 @@
         server_name=server_name, share=share)
 
 
+def read_yaml_config():
+    app_dir = os.path.dirname(os.path.abspath(__file__))
+    config_path = os.path.join(app_dir, 'config.yaml')
+
+    if not os.path.exists(config_path):
+        return None
+
+    print(f"Loading config from {config_path}...")
+    with open(config_path, 'r') as yaml_file:
+        config = yaml.safe_load(yaml_file)
+    return config
+
+
 if __name__ == "__main__":
     fire.Fire(main)
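For context, the app.py change above follows a simple pattern: load config.yaml if it sits next to the script, reject any key that is not already an attribute of the Config class, and apply the rest with setattr. Below is a minimal standalone sketch of that same pattern; AppConfig and its hard-coded attributes are hypothetical stand-ins, not the project's real llama_lora.config.Config.

import os
import yaml


class AppConfig:
    # Hypothetical stand-in for llama_lora.config.Config.
    data_dir = "./data"
    default_base_model_name = ""
    load_8bit = False


def apply_yaml_overrides(config_cls, config_path):
    # A missing config file is a no-op, mirroring read_yaml_config() returning None.
    if not os.path.exists(config_path):
        return

    with open(config_path, "r") as yaml_file:
        overrides = yaml.safe_load(yaml_file) or {}

    for key, value in overrides.items():
        # Unknown keys fail loudly instead of being silently ignored.
        if not hasattr(config_cls, key):
            available_keys = [k for k in vars(config_cls) if not k.startswith("__")]
            raise ValueError(
                f"Invalid config key '{key}'. Available keys: {', '.join(available_keys)}")
        setattr(config_cls, key, value)


if __name__ == "__main__":
    apply_yaml_overrides(AppConfig, "config.yaml")
    print(AppConfig.data_dir, AppConfig.default_base_model_name, AppConfig.load_8bit)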
config.yaml.sample ADDED
@@ -0,0 +1,22 @@
+# Basic Configurations
+data_dir: ./data
+default_base_model_name: decapoda-research/llama-7b-hf
+base_model_choices:
+- decapoda-research/llama-7b-hf
+- nomic-ai/gpt4all-j
+load_8bit: false
+trust_remote_code: false
+
+# UI Customization
+# ui_title: LLM Tuner
+# ui_emoji: 🦙🎛️
+# ui_subtitle: Have fun!
+# ui_show_sys_info: true
+
+# WandB
+# enable_wandb: false
+# wandb_api_key: ""
+# default_wandb_project: LLM-Tuner
+
+# Special Modes
+ui_dev_mode: false
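To try the sample, it would presumably be copied to config.yaml next to app.py, where read_yaml_config() looks for it; the new .gitignore entry keeps that copy out of version control. Since the YAML values are applied before the command-line arguments in main(), a flag such as --base_model passed via fire still overrides the corresponding YAML setting.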