add dataset_id arg
Browse files
train_dreambooth_lora_sdxl.py
CHANGED
@@ -64,7 +64,7 @@ logger = get_logger(__name__)
|
|
64 |
|
65 |
|
66 |
def save_model_card(
|
67 |
-
repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
|
68 |
):
|
69 |
img_str = ""
|
70 |
for i, image in enumerate(images):
|
@@ -82,6 +82,8 @@ tags:
|
|
82 |
- diffusers
|
83 |
- lora
|
84 |
inference: false
|
|
|
|
|
85 |
---
|
86 |
"""
|
87 |
model_card = f"""
|
@@ -180,6 +182,13 @@ def parse_args(input_args=None):
|
|
180 |
required=False,
|
181 |
help="Revision of pretrained model identifier from huggingface.co/models.",
|
182 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
parser.add_argument(
|
184 |
"--instance_data_dir",
|
185 |
type=str,
|
@@ -1386,6 +1395,7 @@ def main(args):
|
|
1386 |
save_model_card(
|
1387 |
repo_id,
|
1388 |
images=images,
|
|
|
1389 |
base_model=args.pretrained_model_name_or_path,
|
1390 |
train_text_encoder=args.train_text_encoder,
|
1391 |
prompt=args.instance_prompt,
|
|
|
64 |
|
65 |
|
66 |
def save_model_card(
|
67 |
+
repo_id: str, images=None, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
|
68 |
):
|
69 |
img_str = ""
|
70 |
for i, image in enumerate(images):
|
|
|
82 |
- diffusers
|
83 |
- lora
|
84 |
inference: false
|
85 |
+
datasets:
|
86 |
+
- {dataset_id}
|
87 |
---
|
88 |
"""
|
89 |
model_card = f"""
|
|
|
182 |
required=False,
|
183 |
help="Revision of pretrained model identifier from huggingface.co/models.",
|
184 |
)
|
185 |
+
parser.add_argument(
|
186 |
+
"--dataset_id",
|
187 |
+
type=str,
|
188 |
+
default=None,
|
189 |
+
required=True,
|
190 |
+
help="The dataset ID you want to train images from",
|
191 |
+
)
|
192 |
parser.add_argument(
|
193 |
"--instance_data_dir",
|
194 |
type=str,
|
|
|
1395 |
save_model_card(
|
1396 |
repo_id,
|
1397 |
images=images,
|
1398 |
+
dataset_id=args.dataset_id,
|
1399 |
base_model=args.pretrained_model_name_or_path,
|
1400 |
train_text_encoder=args.train_text_encoder,
|
1401 |
prompt=args.instance_prompt,
|