Spaces:
Build error
Build error
Update trainer.py
Browse files- trainer.py +17 -9
trainer.py
CHANGED
@@ -60,11 +60,13 @@ class Trainer:
|
|
60 |
resolution_s: str,
|
61 |
concept_images: list | None,
|
62 |
concept_prompt: str,
|
|
|
63 |
n_steps: int,
|
64 |
learning_rate: float,
|
65 |
train_text_encoder: bool,
|
66 |
learning_rate_text: float,
|
67 |
gradient_accumulation: int,
|
|
|
68 |
fp16: bool,
|
69 |
use_8bit_adam: bool,
|
70 |
) -> tuple[dict, list[pathlib.Path]]:
|
@@ -85,18 +87,24 @@ class Trainer:
|
|
85 |
self.prepare_dataset(concept_images, resolution)
|
86 |
|
87 |
command = f'''
|
88 |
-
accelerate launch
|
89 |
-
--pretrained_model_name_or_path={base_model}
|
90 |
-
--instance_data_dir={self.instance_data_dir}
|
|
|
91 |
--output_dir={self.output_dir} \
|
|
|
92 |
--instance_prompt="{concept_prompt}" \
|
93 |
-
--
|
94 |
-
--
|
95 |
-
--
|
96 |
-
--
|
97 |
-
--
|
|
|
98 |
--lr_warmup_steps=0 \
|
99 |
-
--max_train_steps={n_steps}
|
|
|
|
|
|
|
100 |
'''
|
101 |
if fp16:
|
102 |
command += ' --mixed_precision fp16'
|
|
|
60 |
resolution_s: str,
|
61 |
concept_images: list | None,
|
62 |
concept_prompt: str,
|
63 |
+
class_prompt: str,
|
64 |
n_steps: int,
|
65 |
learning_rate: float,
|
66 |
train_text_encoder: bool,
|
67 |
learning_rate_text: float,
|
68 |
gradient_accumulation: int,
|
69 |
+
batch-size: int,
|
70 |
fp16: bool,
|
71 |
use_8bit_adam: bool,
|
72 |
) -> tuple[dict, list[pathlib.Path]]:
|
|
|
87 |
self.prepare_dataset(concept_images, resolution)
|
88 |
|
89 |
command = f'''
|
90 |
+
accelerate launch custom-diffusion/src/diffuser_training.py \
|
91 |
+
--pretrained_model_name_or_path={base_model} \
|
92 |
+
--instance_data_dir={self.instance_data_dir} \
|
93 |
+
--class_data_dir={self.class_data_dir} \
|
94 |
--output_dir={self.output_dir} \
|
95 |
+
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
96 |
--instance_prompt="{concept_prompt}" \
|
97 |
+
--class_prompt="{class_prompt}" \
|
98 |
+
--resolution={resolution} \
|
99 |
+
--train_batch_size={batch-size} \
|
100 |
+
--gradient_accumulation_steps={gradient_accumulation} \
|
101 |
+
--learning_rate={learning_rate} \
|
102 |
+
--lr_scheduler="constant" \
|
103 |
--lr_warmup_steps=0 \
|
104 |
+
--max_train_steps={n_steps} \
|
105 |
+
--num_class_images=200 \
|
106 |
+
--scale_lr \
|
107 |
+
--modifier_token "<new1>"
|
108 |
'''
|
109 |
if fp16:
|
110 |
command += ' --mixed_precision fp16'
|