zetavg commited on
Commit
b39fdac
·
unverified ·
1 Parent(s): 38fb491

support adding LoRA Target Modules choices

Browse files
llama_lora/ui/finetune_ui.py CHANGED
@@ -573,6 +573,7 @@ def handle_load_params_from_model(
573
  save_steps,
574
  save_total_limit,
575
  logging_steps,
 
576
  ):
577
  error_message = ""
578
  notice_message = ""
@@ -621,6 +622,9 @@ def handle_load_params_from_model(
621
  lora_dropout = value
622
  elif key == "lora_target_modules":
623
  lora_target_modules = value
 
 
 
624
  elif key == "save_steps":
625
  save_steps = value
626
  elif key == "save_total_limit":
@@ -658,13 +662,24 @@ def handle_load_params_from_model(
658
  lora_r,
659
  lora_alpha,
660
  lora_dropout,
661
- lora_target_modules,
662
  save_steps,
663
  save_total_limit,
664
  logging_steps,
 
665
  )
666
 
667
 
 
 
 
 
 
 
 
 
 
 
668
  def finetune_ui():
669
  things_that_might_timeout = []
670
 
@@ -896,12 +911,31 @@ def finetune_ui():
896
  info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
897
  )
898
 
 
 
899
  lora_target_modules = gr.CheckboxGroup(
900
  label="LoRA Target Modules",
901
- choices=["q_proj", "k_proj", "v_proj", "o_proj"],
902
  value=["q_proj", "v_proj"],
903
- info="Modules to replace with LoRA."
 
904
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
905
 
906
  with gr.Row():
907
  logging_steps = gr.Number(
@@ -926,6 +960,7 @@ def finetune_ui():
926
  with gr.Column():
927
  model_name = gr.Textbox(
928
  lines=1, label="LoRA Model Name", value=random_name,
 
929
  info="The name of the new LoRA model.",
930
  elem_id="finetune_model_name",
931
  )
@@ -993,8 +1028,8 @@ def finetune_ui():
993
  things_that_might_timeout.append(
994
  load_params_from_model_btn.click(
995
  fn=handle_load_params_from_model,
996
- inputs=[continue_from_model] + finetune_args,
997
- outputs=[load_params_from_model_message] + finetune_args
998
  )
999
  )
1000
 
 
573
  save_steps,
574
  save_total_limit,
575
  logging_steps,
576
+ lora_target_module_choices,
577
  ):
578
  error_message = ""
579
  notice_message = ""
 
622
  lora_dropout = value
623
  elif key == "lora_target_modules":
624
  lora_target_modules = value
625
+ for element in value:
626
+ if element not in lora_target_module_choices:
627
+ lora_target_module_choices.append(element)
628
  elif key == "save_steps":
629
  save_steps = value
630
  elif key == "save_total_limit":
 
662
  lora_r,
663
  lora_alpha,
664
  lora_dropout,
665
+ gr.CheckboxGroup.update(value=lora_target_modules, choices=lora_target_module_choices),
666
  save_steps,
667
  save_total_limit,
668
  logging_steps,
669
+ lora_target_module_choices,
670
  )
671
 
672
 
673
+ default_lora_target_module_choices = ["q_proj", "k_proj", "v_proj", "o_proj"]
674
+
675
+
676
+ def handle_lora_target_modules_add(choices, new_module, selected_modules):
677
+ choices.append(new_module)
678
+ selected_modules.append(new_module)
679
+
680
+ return (choices, "", gr.CheckboxGroup.update(value=selected_modules, choices=choices))
681
+
682
+
683
  def finetune_ui():
684
  things_that_might_timeout = []
685
 
 
911
  info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
912
  )
913
 
914
+ lora_target_module_choices = gr.State(value=default_lora_target_module_choices)
915
+
916
  lora_target_modules = gr.CheckboxGroup(
917
  label="LoRA Target Modules",
918
+ choices=default_lora_target_module_choices,
919
  value=["q_proj", "v_proj"],
920
+ info="Modules to replace with LoRA.",
921
+ elem_id="finetune_lora_target_modules"
922
  )
923
+ with gr.Box(elem_id="finetune_lora_target_modules_add_box"):
924
+ with gr.Row():
925
+ lora_target_modules_add = gr.Textbox(
926
+ lines=1, max_lines=1, show_label=False,
927
+ elem_id="finetune_lora_target_modules_add"
928
+ )
929
+ lora_target_modules_add_btn = gr.Button(
930
+ "Add",
931
+ elem_id="finetune_lora_target_modules_add_btn"
932
+ )
933
+ lora_target_modules_add_btn.style(full_width=False, size="sm")
934
+ things_that_might_timeout.append(lora_target_modules_add_btn.click(
935
+ handle_lora_target_modules_add,
936
+ inputs=[lora_target_module_choices, lora_target_modules_add, lora_target_modules],
937
+ outputs=[lora_target_module_choices, lora_target_modules_add, lora_target_modules],
938
+ ))
939
 
940
  with gr.Row():
941
  logging_steps = gr.Number(
 
960
  with gr.Column():
961
  model_name = gr.Textbox(
962
  lines=1, label="LoRA Model Name", value=random_name,
963
+ max_lines=1,
964
  info="The name of the new LoRA model.",
965
  elem_id="finetune_model_name",
966
  )
 
1028
  things_that_might_timeout.append(
1029
  load_params_from_model_btn.click(
1030
  fn=handle_load_params_from_model,
1031
+ inputs=[continue_from_model] + finetune_args + [lora_target_module_choices],
1032
+ outputs=[load_params_from_model_message] + finetune_args + [lora_target_module_choices]
1033
  )
1034
  )
1035
 
llama_lora/ui/main_page.py CHANGED
@@ -568,6 +568,27 @@ def main_page_custom_css():
568
  flex: 2;
569
  }
570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  #finetune_save_total_limit,
572
  #finetune_save_steps,
573
  #finetune_logging_steps {
 
568
  flex: 2;
569
  }
570
 
571
+ #finetune_lora_target_modules_add_box {
572
+ margin-top: -24px;
573
+ padding-top: 8px;
574
+ border-top-left-radius: 0;
575
+ border-top-right-radius: 0;
576
+ border-top: 0;
577
+ }
578
+ #finetune_lora_target_modules_add_box > * > .form {
579
+ border: 0;
580
+ box-shadow: none;
581
+ }
582
+ #finetune_lora_target_modules_add {
583
+ padding: 0;
584
+ }
585
+ #finetune_lora_target_modules_add input {
586
+ padding: 4px 8px;
587
+ }
588
+ #finetune_lora_target_modules_add_btn {
589
+ min-width: 60px;
590
+ }
591
+
592
  #finetune_save_total_limit,
593
  #finetune_save_steps,
594
  #finetune_logging_steps {