fffiloni commited on
Commit
0fcf13e
1 Parent(s): 2d40264

add dataset_id arg

Browse files
Files changed (1) hide show
  1. train_dreambooth_lora_sdxl.py +11 -1
train_dreambooth_lora_sdxl.py CHANGED
@@ -64,7 +64,7 @@ logger = get_logger(__name__)
64
 
65
 
66
  def save_model_card(
67
- repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
68
  ):
69
  img_str = ""
70
  for i, image in enumerate(images):
@@ -82,6 +82,8 @@ tags:
82
  - diffusers
83
  - lora
84
  inference: false
 
 
85
  ---
86
  """
87
  model_card = f"""
@@ -180,6 +182,13 @@ def parse_args(input_args=None):
180
  required=False,
181
  help="Revision of pretrained model identifier from huggingface.co/models.",
182
  )
 
 
 
 
 
 
 
183
  parser.add_argument(
184
  "--instance_data_dir",
185
  type=str,
@@ -1386,6 +1395,7 @@ def main(args):
1386
  save_model_card(
1387
  repo_id,
1388
  images=images,
 
1389
  base_model=args.pretrained_model_name_or_path,
1390
  train_text_encoder=args.train_text_encoder,
1391
  prompt=args.instance_prompt,
 
64
 
65
 
66
  def save_model_card(
67
+ repo_id: str, images=None, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
68
  ):
69
  img_str = ""
70
  for i, image in enumerate(images):
 
82
  - diffusers
83
  - lora
84
  inference: false
85
+ datasets:
86
+ - {dataset_id}
87
  ---
88
  """
89
  model_card = f"""
 
182
  required=False,
183
  help="Revision of pretrained model identifier from huggingface.co/models.",
184
  )
185
+ parser.add_argument(
186
+ "--dataset_id",
187
+ type=str,
188
+ default=None,
189
+ required=True,
190
+ help="The dataset ID you want to train images from",
191
+ )
192
  parser.add_argument(
193
  "--instance_data_dir",
194
  type=str,
 
1395
  save_model_card(
1396
  repo_id,
1397
  images=images,
1398
+ dataset_id=args.dataset_id,
1399
  base_model=args.pretrained_model_name_or_path,
1400
  train_text_encoder=args.train_text_encoder,
1401
  prompt=args.instance_prompt,