zetavg commited on
Commit
cff173a
Β·
1 Parent(s): 69fa725
llama_lora/globals.py CHANGED
@@ -49,6 +49,7 @@ class Global:
49
  training_eta: Union[int, None] = None
50
  train_output: Union[None, Any] = None
51
  train_output_str: Union[None, str] = None
 
52
 
53
  # Generation Control
54
  should_stop_generating: bool = False
 
49
  training_eta: Union[int, None] = None
50
  train_output: Union[None, Any] = None
51
  train_output_str: Union[None, str] = None
52
+ training_params_info_text: str = ""
53
 
54
  # Generation Control
55
  should_stop_generating: bool = False
llama_lora/lib/finetune.py CHANGED
@@ -70,10 +70,12 @@ def train(
70
  wandb_tags: List[str] = [],
71
  wandb_watch: str = "false", # options: false | gradients | all
72
  wandb_log_model: str = "true", # options: false | true
 
73
  status_message_callback: Any = None,
 
74
  ):
75
  if status_message_callback:
76
- cb_result = status_message_callback("Preparing training...")
77
  if cb_result:
78
  return
79
 
@@ -163,6 +165,8 @@ def train(
163
  config={'finetune_args': finetune_args},
164
  # id=None # used for resuming
165
  )
 
 
166
  else:
167
  os.environ['WANDB_MODE'] = "disabled"
168
 
@@ -294,6 +298,10 @@ def train(
294
  if use_wandb and wandb:
295
  wandb.config.update({"model": {"all_params": all_params, "trainable_params": trainable_params,
296
  "trainable%": 100 * trainable_params / all_params}})
 
 
 
 
297
 
298
  if status_message_callback:
299
  cb_result = status_message_callback("Preparing train data...")
 
70
  wandb_tags: List[str] = [],
71
  wandb_watch: str = "false", # options: false | gradients | all
72
  wandb_log_model: str = "true", # options: false | true
73
+ additional_wandb_config: Union[dict, None] = None,
74
  status_message_callback: Any = None,
75
+ params_info_callback: Any = None,
76
  ):
77
  if status_message_callback:
78
+ cb_result = status_message_callback("Preparing...")
79
  if cb_result:
80
  return
81
 
 
165
  config={'finetune_args': finetune_args},
166
  # id=None # used for resuming
167
  )
168
+ if additional_wandb_config:
169
+ wandb.config.update(additional_wandb_config)
170
  else:
171
  os.environ['WANDB_MODE'] = "disabled"
172
 
 
298
  if use_wandb and wandb:
299
  wandb.config.update({"model": {"all_params": all_params, "trainable_params": trainable_params,
300
  "trainable%": 100 * trainable_params / all_params}})
301
+ if params_info_callback:
302
+ cb_result = params_info_callback(all_params=all_params, trainable_params=trainable_params)
303
+ if cb_result:
304
+ return
305
 
306
  if status_message_callback:
307
  cb_result = status_message_callback("Preparing train data...")
llama_lora/ui/finetune/style.css CHANGED
@@ -271,6 +271,7 @@
271
  #finetune_training_status .progress-block {
272
  min-height: 100px;
273
  display: flex;
 
274
  justify-content: center;
275
  align-items: center;
276
  background: var(--panel-background-fill);
@@ -300,11 +301,14 @@
300
  white-space: pre-wrap;
301
  }
302
  #finetune_training_status .progress-block .progress-level {
 
303
  display: flex;
304
  flex-direction: column;
 
305
  align-items: center;
306
  z-index: var(--layer-2);
307
  width: var(--size-full);
 
308
  }
309
  #finetune_training_status .progress-block .progress-level-inner {
310
  margin: var(--size-2) auto;
@@ -326,6 +330,17 @@
326
  transition: all 150ms ease 0s;
327
  }
328
 
 
 
 
 
 
 
 
 
 
 
 
329
  #finetune_training_status .progress-block .output {
330
  display: flex;
331
  flex-direction: column;
 
271
  #finetune_training_status .progress-block {
272
  min-height: 100px;
273
  display: flex;
274
+ flex-direction: column;
275
  justify-content: center;
276
  align-items: center;
277
  background: var(--panel-background-fill);
 
301
  white-space: pre-wrap;
302
  }
303
  #finetune_training_status .progress-block .progress-level {
304
+ flex-grow: 1;
305
  display: flex;
306
  flex-direction: column;
307
+ justify-content: center;
308
  align-items: center;
309
  z-index: var(--layer-2);
310
  width: var(--size-full);
311
+ padding: 8px 0;
312
  }
313
  #finetune_training_status .progress-block .progress-level-inner {
314
  margin: var(--size-2) auto;
 
330
  transition: all 150ms ease 0s;
331
  }
332
 
333
+ #finetune_training_status .progress-block .params-info {
334
+ font-size: var(--text-sm);
335
+ font-weight: var(--weight-light);
336
+ margin-top: 8px;
337
+ margin-bottom: -4px !important;
338
+ opacity: 0.4;
339
+ }
340
+ #finetune_training_status .progress-block .progress-level + .params-info {
341
+ margin-top: -8px;
342
+ }
343
+
344
  #finetune_training_status .progress-block .output {
345
  display: flex;
346
  flex-direction: column;
llama_lora/ui/finetune/training.py CHANGED
@@ -29,6 +29,10 @@ def status_message_callback(message):
29
  Global.training_status_text = message
30
 
31
 
 
 
 
 
32
  def do_train(
33
  # Dataset
34
  template,
@@ -262,6 +266,8 @@ def do_train(
262
  train_data=train_data,
263
  callbacks=training_callbacks,
264
  status_message_callback=status_message_callback,
 
 
265
  **finetune_args,
266
  )
267
 
@@ -325,6 +331,14 @@ def render_training_status():
325
  end_message = "βœ… Training completed"
326
  if Global.should_stop_training:
327
  end_message = "πŸ›‘ Train aborted"
 
 
 
 
 
 
 
 
328
  html_content = f"""
329
  <div class="progress-block">
330
  <div class="progress-level">
@@ -335,6 +349,7 @@ def render_training_status():
335
  <div class="message">{Global.train_output_str}</div>
336
  </div>
337
  </div>
 
338
  </div>
339
  """
340
  return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
@@ -371,6 +386,13 @@ def render_training_status():
371
  else:
372
  meta_info.append(format_time(time_elapsed))
373
 
 
 
 
 
 
 
 
374
  html_content = f"""
375
  <div class="progress-block is_training">
376
  <div class="meta-text">{' | '.join(meta_info)}</div>
@@ -383,6 +405,7 @@ def render_training_status():
383
  </div>
384
  </div>
385
  </div>
 
386
  </div>
387
  """
388
  return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
 
29
  Global.training_status_text = message
30
 
31
 
32
+ def params_info_callback(all_params, trainable_params):
33
+ Global.training_params_info_text = f"Params: {trainable_params}/{all_params} ({100 * trainable_params / all_params:.4f}% trainable)"
34
+
35
+
36
  def do_train(
37
  # Dataset
38
  template,
 
266
  train_data=train_data,
267
  callbacks=training_callbacks,
268
  status_message_callback=status_message_callback,
269
+ params_info_callback=params_info_callback,
270
+ additional_wandb_config=info,
271
  **finetune_args,
272
  )
273
 
 
331
  end_message = "βœ… Training completed"
332
  if Global.should_stop_training:
333
  end_message = "πŸ›‘ Train aborted"
334
+
335
+ params_info_html = ""
336
+ if Global.training_params_info_text:
337
+ params_info_html = f"""
338
+ <div class="params-info">
339
+ {Global.training_params_info_text}
340
+ </div>
341
+ """
342
  html_content = f"""
343
  <div class="progress-block">
344
  <div class="progress-level">
 
349
  <div class="message">{Global.train_output_str}</div>
350
  </div>
351
  </div>
352
+ {params_info_html}
353
  </div>
354
  """
355
  return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
 
386
  else:
387
  meta_info.append(format_time(time_elapsed))
388
 
389
+ params_info_html = ""
390
+ if Global.training_params_info_text:
391
+ params_info_html = f"""
392
+ <div class="params-info">
393
+ {Global.training_params_info_text}
394
+ </div>
395
+ """
396
  html_content = f"""
397
  <div class="progress-block is_training">
398
  <div class="meta-text">{' | '.join(meta_info)}</div>
 
405
  </div>
406
  </div>
407
  </div>
408
+ {params_info_html}
409
  </div>
410
  """
411
  return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
llama_lora/ui/trainer_callback.py CHANGED
@@ -24,6 +24,7 @@ def reset_training_status():
24
  Global.training_eta = None
25
  Global.train_output = None
26
  Global.train_output_str = None
 
27
 
28
 
29
  def get_progress_text(current_epoch, total_epochs, last_loss):
 
24
  Global.training_eta = None
25
  Global.train_output = None
26
  Global.train_output_str = None
27
+ Global.training_params_info_text = ""
28
 
29
 
30
  def get_progress_text(current_epoch, total_epochs, last_loss):