zetavg commited on
Commit
a5d7977
Β·
1 Parent(s): fdcd724

add wandb support

Browse files
.gitignore CHANGED
@@ -3,4 +3,5 @@ __pycache__/
3
  /venv
4
  .vscode
5
 
 
6
  /data
 
3
  /venv
4
  .vscode
5
 
6
+ /wandb
7
  /data
README.md CHANGED
@@ -60,13 +60,14 @@ file_mounts:
60
  setup: |
61
  git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora_tuner
62
  cd llama_lora_tuner && pip install -r requirements.lock.txt
 
63
  cd ..
64
  echo 'Dependencies installed.'
65
 
66
  # Start the app.
67
  run: |
68
  echo 'Starting...'
69
- python llama_lora_tuner/app.py --data_dir='/data' --base_model='decapoda-research/llama-7b-hf' --share
70
  ```
71
 
72
  Then launch a cluster to run the task:
 
60
  setup: |
61
  git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora_tuner
62
  cd llama_lora_tuner && pip install -r requirements.lock.txt
63
+ pip install wandb
64
  cd ..
65
  echo 'Dependencies installed.'
66
 
67
  # Start the app.
68
  run: |
69
  echo 'Starting...'
70
+ python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key "$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --base_model='decapoda-research/llama-7b-hf' --share
71
  ```
72
 
73
  Then launch a cluster to run the task:
app.py CHANGED
@@ -5,21 +5,37 @@ import fire
5
  import gradio as gr
6
 
7
  from llama_lora.globals import Global
 
8
  from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
9
  from llama_lora.utils.data import init_data_dir
10
 
11
 
 
12
  def main(
13
- load_8bit: bool = False,
14
  base_model: str = "",
15
  data_dir: str = "",
16
  # Allows to listen on all interfaces by providing '0.0.0.0'.
17
  server_name: str = "127.0.0.1",
18
  share: bool = False,
19
  skip_loading_base_model: bool = False,
 
20
  ui_show_sys_info: bool = True,
21
  ui_dev_mode: bool = False,
 
 
22
  ):
 
 
 
 
 
 
 
 
 
 
 
 
23
  base_model = base_model or os.environ.get("LLAMA_LORA_BASE_MODEL", "")
24
  data_dir = data_dir or os.environ.get("LLAMA_LORA_DATA_DIR", "")
25
  assert (
@@ -34,12 +50,22 @@ def main(
34
  Global.data_dir = os.path.abspath(data_dir)
35
  Global.load_8bit = load_8bit
36
 
 
 
 
 
 
 
 
37
  Global.ui_dev_mode = ui_dev_mode
38
  Global.ui_show_sys_info = ui_show_sys_info
39
 
40
  os.makedirs(data_dir, exist_ok=True)
41
  init_data_dir()
42
 
 
 
 
43
  with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
44
  main_page()
45
 
 
5
  import gradio as gr
6
 
7
  from llama_lora.globals import Global
8
+ from llama_lora.models import prepare_base_model
9
  from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
10
  from llama_lora.utils.data import init_data_dir
11
 
12
 
13
+
14
  def main(
 
15
  base_model: str = "",
16
  data_dir: str = "",
17
  # Allows to listen on all interfaces by providing '0.0.0.0'.
18
  server_name: str = "127.0.0.1",
19
  share: bool = False,
20
  skip_loading_base_model: bool = False,
21
+ load_8bit: bool = False,
22
  ui_show_sys_info: bool = True,
23
  ui_dev_mode: bool = False,
24
+ wandb_api_key: str = "",
25
+ wandb_project: str = "",
26
  ):
27
+ '''
28
+ Start the LLaMA-LoRA Tuner UI.
29
+
30
+ :param base_model: (required) The name of the default base model to use.
31
+ :param data_dir: (required) The path to the directory to store data.
32
+ :param server_name: Allows to listen on all interfaces by providing '0.0.0.0'.
33
+ :param share: Create a public Gradio URL.
34
+
35
+ :param wandb_api_key: The API key for Weights & Biases. Setting either this or `wandb_project` will enable Weights & Biases.
36
+ :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
37
+ '''
38
+
39
  base_model = base_model or os.environ.get("LLAMA_LORA_BASE_MODEL", "")
40
  data_dir = data_dir or os.environ.get("LLAMA_LORA_DATA_DIR", "")
41
  assert (
 
50
  Global.data_dir = os.path.abspath(data_dir)
51
  Global.load_8bit = load_8bit
52
 
53
+ if len(wandb_api_key) > 0:
54
+ Global.enable_wandb = True
55
+ Global.wandb_api_key = wandb_api_key
56
+ if len(wandb_project) > 0:
57
+ Global.enable_wandb = True
58
+ Global.wandb_project = wandb_project
59
+
60
  Global.ui_dev_mode = ui_dev_mode
61
  Global.ui_show_sys_info = ui_show_sys_info
62
 
63
  os.makedirs(data_dir, exist_ok=True)
64
  init_data_dir()
65
 
66
+ if (not skip_loading_base_model) and (not ui_dev_mode):
67
+ prepare_base_model(base_model)
68
+
69
  with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
70
  main_page()
71
 
llama_lora/globals.py CHANGED
@@ -40,6 +40,11 @@ class Global:
40
  gpu_total_cores = None # GPU total cores
41
  gpu_total_memory = None
42
 
 
 
 
 
 
43
  # UI related
44
  ui_title: str = "LLaMA-LoRA Tuner"
45
  ui_emoji: str = "πŸ¦™πŸŽ›οΈ"
 
40
  gpu_total_cores = None # GPU total cores
41
  gpu_total_memory = None
42
 
43
+ # WandB
44
+ enable_wandb = False
45
+ wandb_api_key = None
46
+ default_wandb_project = "llama-lora-tuner"
47
+
48
  # UI related
49
  ui_title: str = "LLaMA-LoRA Tuner"
50
  ui_emoji: str = "πŸ¦™πŸŽ›οΈ"
llama_lora/lib/finetune.py CHANGED
@@ -50,8 +50,32 @@ def train(
50
  save_total_limit: int = 3,
51
  logging_steps: int = 10,
52
  # logging
53
- callbacks: List[Any] = []
 
 
 
 
 
 
54
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  if os.path.exists(output_dir):
56
  if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
57
  raise ValueError(
@@ -204,8 +228,8 @@ def train(
204
  load_best_model_at_end=True if val_set_size > 0 else False,
205
  ddp_find_unused_parameters=False if ddp else None,
206
  group_by_length=group_by_length,
207
- # report_to="wandb" if use_wandb else None,
208
- # run_name=wandb_run_name if use_wandb else None,
209
  ),
210
  data_collator=transformers.DataCollatorForSeq2Seq(
211
  tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
 
50
  save_total_limit: int = 3,
51
  logging_steps: int = 10,
52
  # logging
53
+ callbacks: List[Any] = [],
54
+ # wandb params
55
+ wandb_api_key = None,
56
+ wandb_project: str = "",
57
+ wandb_run_name: str = "",
58
+ wandb_watch: str = "false", # options: false | gradients | all
59
+ wandb_log_model: str = "true", # options: false | true
60
  ):
61
+ if wandb_api_key:
62
+ os.environ["WANDB_API_KEY"] = wandb_api_key
63
+ if wandb_project:
64
+ os.environ["WANDB_PROJECT"] = wandb_project
65
+ if wandb_run_name:
66
+ os.environ["WANDB_RUN_NAME"] = wandb_run_name
67
+ if wandb_watch:
68
+ os.environ["WANDB_WATCH"] = wandb_watch
69
+ if wandb_log_model:
70
+ os.environ["WANDB_LOG_MODEL"] = wandb_log_model
71
+ use_wandb = (wandb_project and len(wandb_project) > 0) or (
72
+ "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
73
+ )
74
+ if use_wandb:
75
+ os.environ['WANDB_MODE'] = "online"
76
+ else:
77
+ os.environ['WANDB_MODE'] = "disabled"
78
+
79
  if os.path.exists(output_dir):
80
  if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
81
  raise ValueError(
 
228
  load_best_model_at_end=True if val_set_size > 0 else False,
229
  ddp_find_unused_parameters=False if ddp else None,
230
  group_by_length=group_by_length,
231
+ report_to="wandb" if use_wandb else None,
232
+ run_name=wandb_run_name if use_wandb else None,
233
  ),
234
  data_collator=transformers.DataCollatorForSeq2Seq(
235
  tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
llama_lora/ui/finetune_ui.py CHANGED
@@ -491,7 +491,10 @@ Train data (first 10):
491
  save_steps, # save_steps
492
  save_total_limit, # save_total_limit
493
  logging_steps, # logging_steps
494
- training_callbacks # callbacks
 
 
 
495
  )
496
 
497
  logs_str = "\n".join([json.dumps(log)
 
491
  save_steps, # save_steps
492
  save_total_limit, # save_total_limit
493
  logging_steps, # logging_steps
494
+ training_callbacks, # callbacks
495
+ Global.wandb_api_key, # wandb_api_key
496
+ Global.default_wandb_project if Global.enable_wandb else None, # wandb_project
497
+ model_name # wandb_run_name
498
  )
499
 
500
  logs_str = "\n".join([json.dumps(log)