ClaireOzzz committed on
Commit
fbb8a30
1 Parent(s): e042083

Upload 22 files

README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: Train Dreambooth LoRa Sd-XL
3
+ emoji: 🏆
4
+ colorFrom: red
5
+ colorTo: red
6
+ python_version: 3.10.12
7
+ sdk: gradio
8
+ sdk_version: 3.44.2
9
+ app_file: app.py
10
+ pinned: false
11
+ suggested_hardware: "a10g-small"
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,104 @@
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import shutil
5
+ import requests
6
+ import subprocess
7
+ from subprocess import getoutput
8
+ from huggingface_hub import login, HfFileSystem, snapshot_download, HfApi, create_repo
9
+ from pathlib import Path
10
+ from PIL import Image
11
+
12
+ from app_train import create_training_demo
13
+ from sdxl.app_inference import create_inference_demo
14
+ from depthgltf.app_visualisations import create_visual_demo
15
+
16
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
17
+ import numpy as np
18
+ import open3d as o3d
19
+
20
+
21
+ css="""
22
+ #col-container {max-width: 780px; margin-left: auto; margin-right: auto;}
23
+ #upl-dataset-group {background-color: none!important;}
24
+
25
+ div#warning-ready {
26
+ background-color: #ecfdf5;
27
+ padding: 0 10px 5px;
28
+ margin: 20px 0;
29
+ }
30
+ div#warning-ready > .gr-prose > h2, div#warning-ready > .gr-prose > p {
31
+ color: #057857!important;
32
+ }
33
+
34
+ div#warning-duplicate {
35
+ background-color: #ebf5ff;
36
+ padding: 0 10px 5px;
37
+ margin: 20px 0;
38
+ }
39
+
40
+ div#warning-duplicate > .gr-prose > h2, div#warning-duplicate > .gr-prose > p {
41
+ color: #0f4592!important;
42
+ }
43
+
44
+ div#warning-duplicate strong {
45
+ color: #0f4592;
46
+ }
47
+
48
+ p.actions {
49
+ display: flex;
50
+ align-items: center;
51
+ margin: 20px 0;
52
+ }
53
+
54
+ div#warning-duplicate .actions a {
55
+ display: inline-block;
56
+ margin-right: 10px;
57
+ }
58
+
59
+ div#warning-setgpu {
60
+ background-color: #fff4eb;
61
+ padding: 0 10px 5px;
62
+ margin: 20px 0;
63
+ }
64
+
65
+ div#warning-setgpu > .gr-prose > h2, div#warning-setgpu > .gr-prose > p {
66
+ color: #92220f!important;
67
+ }
68
+
69
+ div#warning-setgpu a, div#warning-setgpu b {
70
+ color: #91230f;
71
+ }
72
+
73
+ div#warning-setgpu p.actions > a {
74
+ display: inline-block;
75
+ background: #1f1f23;
76
+ border-radius: 40px;
77
+ padding: 6px 24px;
78
+ color: antiquewhite;
79
+ text-decoration: none;
80
+ font-weight: 600;
81
+ font-size: 1.2em;
82
+ }
83
+
84
+ button#load-dataset-btn{
85
+ min-height: 60px;
86
+ }
87
+ """
88
+
89
+
90
+ with gr.Blocks(css=css) as demo:
91
+
92
+ gr.Markdown("SUTD x SUNS Shop Design Generator")
93
+ with gr.Tab("Training"):
94
+ create_training_demo()
95
+ with gr.Tab("Generation"):
96
+ create_inference_demo()
97
+
98
+ #create_visual_demo();
99
+ with gr.Tab("Visualisation"):
100
+ create_visual_demo()
101
+
102
+
103
+ demo.queue().launch(debug=True, share=True)
104
+
app_train.py ADDED
@@ -0,0 +1,389 @@
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import shutil
5
+ import requests
6
+ import subprocess
7
+ from subprocess import getoutput
8
+ from huggingface_hub import login, HfFileSystem, snapshot_download, HfApi, create_repo
9
+
10
+ is_gpu_associated = torch.cuda.is_available()
11
+
12
+ is_shared_ui = False
13
+
14
+ hf_token = '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'
15
+
16
+ fs = HfFileSystem(token=hf_token)
17
+ api = HfApi()
18
+
19
+ if is_gpu_associated:
20
+ gpu_info = getoutput('nvidia-smi')
21
+ if("A10G" in gpu_info):
22
+ which_gpu = "A10G"
23
+ elif("T4" in gpu_info):
24
+ which_gpu = "T4"
25
+ else:
26
+ which_gpu = "CPU"
27
+
28
+ def check_upload_or_no(value):
29
+ if value is True:
30
+ return gr.update(visible=True)
31
+ else:
32
+ return gr.update(visible=False)
33
+
34
+ def load_images_to_dataset(images, dataset_name):
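+ # Copy the uploaded images into a local working folder, then push that folder to a new dataset repo under the user's namespace on the Hub.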
35
+
36
+ if is_shared_ui:
37
+ raise gr.Error("This Space only works in duplicated instances")
38
+
39
+ if dataset_name == "":
40
+ raise gr.Error("You forgot to name your new dataset. ")
41
+
42
+ # Create the directory if it doesn't exist
43
+ my_working_directory = f"my_working_directory_for_{dataset_name}"
44
+ if not os.path.exists(my_working_directory):
45
+ os.makedirs(my_working_directory)
46
+
47
+ # Assuming 'images' is a list of image file paths
48
+ for idx, image in enumerate(images):
49
+ # Get the base file name (without path) from the original location
50
+ image_name = os.path.basename(image.name)
51
+
52
+ # Construct the destination path in the working directory
53
+ destination_path = os.path.join(my_working_directory, image_name)
54
+
55
+ # Copy the image from the original location to the working directory
56
+ shutil.copy(image.name, destination_path)
57
+
58
+ # Print the image name and its corresponding save path
59
+ print(f"Image {idx + 1}: {image_name} copied to {destination_path}")
60
+
61
+ path_to_folder = my_working_directory
62
+ your_username = api.whoami(token=hf_token)["name"]
63
+ repo_id = f"{your_username}/{dataset_name}"
64
+ create_repo(repo_id=repo_id, repo_type="dataset", token=hf_token)
65
+
66
+ api.upload_folder(
67
+ folder_path=path_to_folder,
68
+ repo_id=repo_id,
69
+ repo_type="dataset",
70
+ token=hf_token
71
+ )
72
+
73
+ return "Done, your dataset is ready and loaded for the training step!", repo_id
74
+
75
+ def swap_hardware(hf_token, hardware="cpu-basic"):
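+ # Request a hardware change for this Space by POSTing the desired flavor to the Hub, authenticated with the user's token.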
76
+ hardware_url = f"https://huggingface.co/spaces/ClaireOzzz/train-dreambooth-lora-sdxl/hardware"
77
+ headers = { "authorization" : f"Bearer {hf_token}"}
78
+ body = {'flavor': hardware}
79
+ requests.post(hardware_url, json = body, headers=headers)
80
+
81
+ def swap_sleep_time(hf_token,sleep_time):
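+ # Set the Space's auto-sleep timeout (in seconds) through the Hub API; main() passes -1 to keep the Space awake during training.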
82
+ sleep_time_url = f"https://huggingface.co/api/spaces/ClaireOzzz/train-dreambooth-lora-sdxl/sleeptime"
83
+ headers = { "authorization" : f"Bearer {hf_token}"}
84
+ body = {'seconds':sleep_time}
85
+ requests.post(sleep_time_url,json=body,headers=headers)
86
+
87
+ def get_sleep_time(hf_token):
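+ # Fetch the Space's runtime info and return its current gcTimeout, or None if no sleep time is set.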
88
+ sleep_time_url = f"https://huggingface.co/api/spaces/ClaireOzzz/train-dreambooth-lora-sdxl"
89
+ headers = { "authorization" : f"Bearer {hf_token}"}
90
+ response = requests.get(sleep_time_url,headers=headers)
91
+ try:
92
+ gcTimeout = response.json()['runtime']['gcTimeout']
93
+ except:
94
+ gcTimeout = None
95
+ return gcTimeout
96
+
97
+ def write_to_community(title, description,hf_token):
98
+
99
+ api.create_discussion(repo_id="ClaireOzzz/train-dreambooth-lora-sdxl", title=title, description=description, repo_type="space", token=hf_token)
100
+
101
+
102
+ def set_accelerate_default_config():
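+ # Write accelerate's default config file so that `accelerate launch` can run without interactive prompts.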
103
+ try:
104
+ subprocess.run(["accelerate", "config", "default"], check=True)
105
+ print("Accelerate default config set successfully!")
106
+ except subprocess.CalledProcessError as e:
107
+ print(f"An error occurred: {e}")
108
+
109
+ def train_dreambooth_lora_sdxl(dataset_id, instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu):
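+ # Run the local train_dreambooth_lora_sdxl.py script via `accelerate launch`, then release the GPU (or restore the sleep timeout) depending on remove_gpu.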
110
+
111
+ script_filename = "train_dreambooth_lora_sdxl.py" # Assuming it's in the same folder
112
+
113
+ command = [
114
+ "accelerate",
115
+ "launch",
116
+ script_filename, # Use the local script
117
+ "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
118
+ "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
119
+ f"--dataset_id={dataset_id}",
120
+ f"--instance_data_dir={instance_data_dir}",
121
+ f"--output_dir={lora_trained_xl_folder}",
122
+ "--mixed_precision=fp16",
123
+ f"--instance_prompt={instance_prompt}",
124
+ "--resolution=1024",
125
+ "--train_batch_size=2",
126
+ "--gradient_accumulation_steps=2",
127
+ "--gradient_checkpointing",
128
+ "--learning_rate=1e-4",
129
+ "--lr_scheduler=constant",
130
+ "--lr_warmup_steps=0",
131
+ "--enable_xformers_memory_efficient_attention",
132
+ "--mixed_precision=fp16",
133
+ "--use_8bit_adam",
134
+ f"--max_train_steps={max_train_steps}",
135
+ f"--checkpointing_steps={checkpoint_steps}",
136
+ "--seed=0",
137
+ "--push_to_hub",
138
+ f"--hub_token={hf_token}"
139
+ ]
140
+
141
+ try:
142
+ subprocess.run(command, check=True)
143
+ print("Training is finished!")
144
+ if remove_gpu:
145
+ swap_hardware(hf_token, "cpu-basic")
146
+ else:
147
+ swap_sleep_time(hf_token, 300)
148
+ except subprocess.CalledProcessError as e:
149
+ print(f"An error occurred: {e}")
150
+
151
+ title="There was an error on during your training"
152
+ description=f'''
153
+ Unfortunately there was an error while training your {lora_trained_xl_folder} model.
154
+ Please check it out below. Feel free to report this issue to [SD-XL Dreambooth LoRa Training](https://huggingface.co/spaces/fffiloni/train-dreambooth-lora-sdxl):
155
+ ```
156
+ {str(e)}
157
+ ```
158
+ '''
159
+ if remove_gpu:
160
+ swap_hardware(hf_token, "cpu-basic")
161
+ else:
162
+ swap_sleep_time(hf_token, 300)
163
+ #write_to_community(title,description,hf_token)
164
+
165
+ def main(dataset_id,
166
+ lora_trained_xl_folder,
167
+ instance_prompt,
168
+ max_train_steps,
169
+ checkpoint_steps,
170
+ remove_gpu):
171
+
172
+
173
+ if is_shared_ui:
174
+ raise gr.Error("This Space only works in duplicated instances")
175
+
176
+ if not is_gpu_associated:
177
+ raise gr.Error("Please associate a T4 or A10G GPU for this Space")
178
+
179
+ if dataset_id == "":
180
+ raise gr.Error("You forgot to specify an image dataset")
181
+
182
+ if instance_prompt == "":
183
+ raise gr.Error("You forgot to specify a concept prompt")
184
+
185
+ if lora_trained_xl_folder == "":
186
+ raise gr.Error("You forgot to name the output folder for your model")
187
+
188
+ sleep_time = get_sleep_time(hf_token)
189
+ if sleep_time:
190
+ swap_sleep_time(hf_token, -1)
191
+
192
+ gr.Warning("If you did not check the `Remove GPU After training`, don't forget to remove the GPU attribution after you are done. ")
193
+
194
+ dataset_repo = dataset_id
195
+
196
+ # Automatically set local_dir based on the last part of dataset_repo
197
+ repo_parts = dataset_repo.split("/")
198
+ local_dir = f"./{repo_parts[-1]}" # Use the last part of the split
199
+
200
+ # Check if the directory exists and create it if necessary
201
+ if not os.path.exists(local_dir):
202
+ os.makedirs(local_dir)
203
+
204
+ gr.Info("Downloading dataset ...")
205
+
206
+ snapshot_download(
207
+ dataset_repo,
208
+ local_dir=local_dir,
209
+ repo_type="dataset",
210
+ ignore_patterns=".gitattributes",
211
+ token=hf_token
212
+ )
213
+
214
+ set_accelerate_default_config()
215
+
216
+ gr.Info("Training begins ...")
217
+
218
+ instance_data_dir = repo_parts[-1]
219
+ train_dreambooth_lora_sdxl(dataset_id, instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu)
220
+
221
+ your_username = api.whoami(token=hf_token)["name"]
222
+ return f"Done, your trained model has been stored in your models library: {your_username}/{lora_trained_xl_folder}"
223
+
224
+ css="""
225
+ #col-container {max-width: 780px; margin-left: auto; margin-right: auto;}
226
+ #upl-dataset-group {background-color: none!important;}
227
+
228
+ div#warning-ready {
229
+ background-color: #ecfdf5;
230
+ padding: 0 10px 5px;
231
+ margin: 20px 0;
232
+ }
233
+ div#warning-ready > .gr-prose > h2, div#warning-ready > .gr-prose > p {
234
+ color: #057857!important;
235
+ }
236
+
237
+ div#warning-duplicate {
238
+ background-color: #ebf5ff;
239
+ padding: 0 10px 5px;
240
+ margin: 20px 0;
241
+ }
242
+
243
+ div#warning-duplicate > .gr-prose > h2, div#warning-duplicate > .gr-prose > p {
244
+ color: #0f4592!important;
245
+ }
246
+
247
+ div#warning-duplicate strong {
248
+ color: #0f4592;
249
+ }
250
+
251
+ p.actions {
252
+ display: flex;
253
+ align-items: center;
254
+ margin: 20px 0;
255
+ }
256
+
257
+ div#warning-duplicate .actions a {
258
+ display: inline-block;
259
+ margin-right: 10px;
260
+ }
261
+
262
+ div#warning-setgpu {
263
+ background-color: #fff4eb;
264
+ padding: 0 10px 5px;
265
+ margin: 20px 0;
266
+ }
267
+
268
+ div#warning-setgpu > .gr-prose > h2, div#warning-setgpu > .gr-prose > p {
269
+ color: #92220f!important;
270
+ }
271
+
272
+ div#warning-setgpu a, div#warning-setgpu b {
273
+ color: #91230f;
274
+ }
275
+
276
+ div#warning-setgpu p.actions > a {
277
+ display: inline-block;
278
+ background: #1f1f23;
279
+ border-radius: 40px;
280
+ padding: 6px 24px;
281
+ color: antiquewhite;
282
+ text-decoration: none;
283
+ font-weight: 600;
284
+ font-size: 1.2em;
285
+ }
286
+
287
+ button#load-dataset-btn{
288
+ min-height: 60px;
289
+ }
290
+ """
291
+ def create_training_demo() -> gr.Blocks:
292
+ with gr.Blocks(css=css) as demo:
293
+ with gr.Column(elem_id="col-container"):
294
+ if is_shared_ui:
295
+ top_description = gr.HTML(f'''
296
+ <div class="gr-prose">
297
+ <h2><svg xmlns="http://www.w3.org/2000/svg" width="18px" height="18px" style="margin-right: 0px;display: inline-block;"fill="none"><path fill="#fff" d="M7 13.2a6.3 6.3 0 0 0 4.4-10.7A6.3 6.3 0 0 0 .6 6.9 6.3 6.3 0 0 0 7 13.2Z"/><path fill="#fff" fill-rule="evenodd" d="M7 0a6.9 6.9 0 0 1 4.8 11.8A6.9 6.9 0 0 1 0 7 6.9 6.9 0 0 1 7 0Zm0 0v.7V0ZM0 7h.6H0Zm7 6.8v-.6.6ZM13.7 7h-.6.6ZM9.1 1.7c-.7-.3-1.4-.4-2.2-.4a5.6 5.6 0 0 0-4 1.6 5.6 5.6 0 0 0-1.6 4 5.6 5.6 0 0 0 1.6 4 5.6 5.6 0 0 0 4 1.7 5.6 5.6 0 0 0 4-1.7 5.6 5.6 0 0 0 1.7-4 5.6 5.6 0 0 0-1.7-4c-.5-.5-1.1-.9-1.8-1.2Z" clip-rule="evenodd"/><path fill="#000" fill-rule="evenodd" d="M7 2.9a.8.8 0 1 1 0 1.5A.8.8 0 0 1 7 3ZM5.8 5.7c0-.4.3-.6.6-.6h.7c.3 0 .6.2.6.6v3.7h.5a.6.6 0 0 1 0 1.3H6a.6.6 0 0 1 0-1.3h.4v-3a.6.6 0 0 1-.6-.7Z" clip-rule="evenodd"/></svg>
298
+ Attention: this Space needs to be duplicated to work</h2>
299
+ <p class="main-message">
300
+ To make it work, <strong>duplicate the Space</strong> and run it on your own profile using a <strong>private</strong> GPU (T4-small or A10G-small).<br />
301
+ A T4 costs <strong>US$0.60/h</strong>, so it should cost < US$1 to train most models.
302
+ </p>
303
+ <p class="actions">
304
+
305
+ to start training your own image model
306
+ </p>
307
+ </div>
308
+ ''', elem_id="warning-duplicate")
309
+ # else:
310
+ # if(is_gpu_associated):
311
+ # top_description = gr.HTML(f'''
312
+ # <div class="gr-prose">
313
+ # <h2><svg xmlns="http://www.w3.org/2000/svg" width="18px" height="18px" style="margin-right: 0px;display: inline-block;"fill="none"><path fill="#fff" d="M7 13.2a6.3 6.3 0 0 0 4.4-10.7A6.3 6.3 0 0 0 .6 6.9 6.3 6.3 0 0 0 7 13.2Z"/><path fill="#fff" fill-rule="evenodd" d="M7 0a6.9 6.9 0 0 1 4.8 11.8A6.9 6.9 0 0 1 0 7 6.9 6.9 0 0 1 7 0Zm0 0v.7V0ZM0 7h.6H0Zm7 6.8v-.6.6ZM13.7 7h-.6.6ZM9.1 1.7c-.7-.3-1.4-.4-2.2-.4a5.6 5.6 0 0 0-4 1.6 5.6 5.6 0 0 0-1.6 4 5.6 5.6 0 0 0 1.6 4 5.6 5.6 0 0 0 4 1.7 5.6 5.6 0 0 0 4-1.7 5.6 5.6 0 0 0 1.7-4 5.6 5.6 0 0 0-1.7-4c-.5-.5-1.1-.9-1.8-1.2Z" clip-rule="evenodd"/><path fill="#000" fill-rule="evenodd" d="M7 2.9a.8.8 0 1 1 0 1.5A.8.8 0 0 1 7 3ZM5.8 5.7c0-.4.3-.6.6-.6h.7c.3 0 .6.2.6.6v3.7h.5a.6.6 0 0 1 0 1.3H6a.6.6 0 0 1 0-1.3h.4v-3a.6.6 0 0 1-.6-.7Z" clip-rule="evenodd"/></svg>
314
+ # You have successfully associated a {which_gpu} GPU to the SD-XL Training Space 🎉</h2>
315
+ # <p>
316
+ # You can now train your model! You will be billed by the minute from when you activated the GPU until when it is turned off.
317
+ # </p>
318
+ # </div>
319
+ # ''', elem_id="warning-ready")
320
+ # else:
321
+ # top_description = gr.HTML(f'''
322
+ # <div class="gr-prose">
323
+ # <h2><svg xmlns="http://www.w3.org/2000/svg" width="18px" height="18px" style="margin-right: 0px;display: inline-block;"fill="none"><path fill="#fff" d="M7 13.2a6.3 6.3 0 0 0 4.4-10.7A6.3 6.3 0 0 0 .6 6.9 6.3 6.3 0 0 0 7 13.2Z"/><path fill="#fff" fill-rule="evenodd" d="M7 0a6.9 6.9 0 0 1 4.8 11.8A6.9 6.9 0 0 1 0 7 6.9 6.9 0 0 1 7 0Zm0 0v.7V0ZM0 7h.6H0Zm7 6.8v-.6.6ZM13.7 7h-.6.6ZM9.1 1.7c-.7-.3-1.4-.4-2.2-.4a5.6 5.6 0 0 0-4 1.6 5.6 5.6 0 0 0-1.6 4 5.6 5.6 0 0 0 1.6 4 5.6 5.6 0 0 0 4 1.7 5.6 5.6 0 0 0 4-1.7 5.6 5.6 0 0 0 1.7-4 5.6 5.6 0 0 0-1.7-4c-.5-.5-1.1-.9-1.8-1.2Z" clip-rule="evenodd"/><path fill="#000" fill-rule="evenodd" d="M7 2.9a.8.8 0 1 1 0 1.5A.8.8 0 0 1 7 3ZM5.8 5.7c0-.4.3-.6.6-.6h.7c.3 0 .6.2.6.6v3.7h.5a.6.6 0 0 1 0 1.3H6a.6.6 0 0 1 0-1.3h.4v-3a.6.6 0 0 1-.6-.7Z" clip-rule="evenodd"/></svg>
324
+ # You have successfully duplicated the SD-XL Training Space 🎉</h2>
325
+ # <p>There's only one step left before you can train your model: <a href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}/settings" style="text-decoration: underline" target="_blank">attribute a <b>T4-small or A10G-small GPU</b> to it (via the Settings tab)</a> and run the training below.
326
+ # You will be billed by the minute from when you activate the GPU until when it is turned off.</p>
327
+ # <p class="actions">
328
+ # <a href="https://huggingface.co/spaces/ClaireOzzz/train-dreambooth-lora-sdxl/settings">🔥 &nbsp; Set recommended GPU</a>
329
+ # </p>
330
+ # </div>
331
+ # ''', elem_id="warning-setgpu")
332
+
333
+ gr.Markdown("# SD-XL Dreambooth LoRa Training UI 💭")
334
+
335
+ upload_my_images = gr.Checkbox(label="Drop your training images? (optional)", value=False)
336
+ gr.Markdown("Use this step to upload your training images and create a new dataset. If you already have a dataset stored on your HF profile, you can skip this step, and provide your dataset ID in the training `Datased ID` input below.")
337
+
338
+ with gr.Group(visible=False, elem_id="upl-dataset-group") as upload_group:
339
+ with gr.Row():
340
+ images = gr.File(file_types=["image"], label="Upload your images", file_count="multiple", interactive=True, visible=True)
341
+ with gr.Column():
342
+ new_dataset_name = gr.Textbox(label="Set new dataset name", placeholder="e.g.: my_awesome_dataset")
343
+ dataset_status = gr.Textbox(label="dataset status")
344
+ load_btn = gr.Button("Load images to new dataset", elem_id="load-dataset-btn")
345
+
346
+ gr.Markdown("## Training ")
347
+ gr.Markdown("You can use an existing image dataset, find a dataset example here: [https://huggingface.co/datasets/diffusers/dog-example](https://huggingface.co/datasets/diffusers/dog-example) ;)")
348
+
349
+ with gr.Row():
350
+ dataset_id = gr.Textbox(label="Dataset ID", info="use one of your previously uploaded image datasets on your HF profile", placeholder="diffusers/dog-example")
351
+ instance_prompt = gr.Textbox(label="Concept prompt", info="concept prompt - use a unique, made up word to avoid collisions")
352
+
353
+ with gr.Row():
354
+ model_output_folder = gr.Textbox(label="Output model folder name", placeholder="lora-trained-xl-folder")
355
+ max_train_steps = gr.Number(label="Max Training Steps", value=500, precision=0, step=10)
356
+ checkpoint_steps = gr.Number(label="Checkpoints Steps", value=100, precision=0, step=10)
357
+
358
+ remove_gpu = gr.Checkbox(label="Remove GPU After Training", value=True, info="If NOT enabled, don't forget to remove the GPU attribution after you are done.")
359
+ train_button = gr.Button("Train !")
360
+
361
+ train_status = gr.Textbox(label="Training status")
362
+
363
+ upload_my_images.change(
364
+ fn = check_upload_or_no,
365
+ inputs =[upload_my_images],
366
+ outputs = [upload_group]
367
+ )
368
+
369
+ load_btn.click(
370
+ fn = load_images_to_dataset,
371
+ inputs = [images, new_dataset_name],
372
+ outputs = [dataset_status, dataset_id]
373
+ )
374
+
375
+ train_button.click(
376
+ fn = main,
377
+ inputs = [
378
+ dataset_id,
379
+ model_output_folder,
380
+ instance_prompt,
381
+ max_train_steps,
382
+ checkpoint_steps,
383
+ remove_gpu
384
+ ],
385
+ outputs = [train_status]
386
+ )
387
+ return demo
388
+
389
+ #demo.launch(debug=True, share=True)
depthgltf/.gitattributes ADDED
@@ -0,0 +1,27 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
depthgltf/.gitignore ADDED
@@ -0,0 +1,46 @@
1
+ # Python build
2
+ .eggs/
3
+ gradio.egg-info/*
4
+ !gradio.egg-info/requires.txt
5
+ !gradio.egg-info/PKG-INFO
6
+ dist/
7
+ *.pyc
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+ build/
12
+
13
+ # JS build
14
+ gradio/templates/frontend
15
+ # Secrets
16
+ .env
17
+
18
+ # Gradio run artifacts
19
+ *.db
20
+ *.sqlite3
21
+ gradio/launches.json
22
+ flagged/
23
+ gradio_cached_examples/
24
+
25
+ # Tests
26
+ .coverage
27
+ coverage.xml
28
+ test.txt
29
+
30
+ # Demos
31
+ demo/tmp.zip
32
+ demo/files/*.avi
33
+ demo/files/*.mp4
34
+
35
+ # Etc
36
+ .idea/*
37
+ .DS_Store
38
+ *.bak
39
+ workspace.code-workspace
40
+ *.h5
41
+ .vscode/
42
+
43
+ # log files
44
+ .pnpm-debug.log
45
+ venv/
46
+ *.db-journal
depthgltf/README.md ADDED
@@ -0,0 +1,12 @@
1
+ ---
2
+ title: Dpt Depth Estimation + 3D
3
+ emoji: ⚡
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.0b8
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
depthgltf/app_visualisations.py ADDED
@@ -0,0 +1,133 @@
1
+ import gradio as gr
2
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import open3d as o3d
7
+ from pathlib import Path
8
+ import os
9
+
10
+ feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
11
+ model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
12
+
13
+
14
+ def process_image(image_path):
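+ # Resize the input to 800 px wide, run DPT depth estimation, and return the predicted depth map together with a reconstructed glTF mesh.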
15
+ image_path = Path(image_path)
16
+ image_raw = Image.open(image_path)
17
+ image = image_raw.resize(
18
+ (800, int(800 * image_raw.size[1] / image_raw.size[0])),
19
+ Image.Resampling.LANCZOS)
20
+
21
+ # prepare image for the model
22
+ encoding = feature_extractor(image, return_tensors="pt")
23
+
24
+ # forward pass
25
+ with torch.no_grad():
26
+ outputs = model(**encoding)
27
+ predicted_depth = outputs.predicted_depth
28
+
29
+ # interpolate to original size
30
+ prediction = torch.nn.functional.interpolate(
31
+ predicted_depth.unsqueeze(1),
32
+ size=image.size[::-1],
33
+ mode="bicubic",
34
+ align_corners=False,
35
+ ).squeeze()
36
+ output = prediction.cpu().numpy()
37
+ depth_image = (output * 255 / np.max(output)).astype('uint8')
38
+ try:
39
+ gltf_path = create_3d_obj(np.array(image), depth_image, image_path)
40
+ img = Image.fromarray(depth_image)
41
+ return [img, gltf_path, gltf_path]
42
+ except Exception as e:
43
+ gltf_path = create_3d_obj(
44
+ np.array(image), depth_image, image_path, depth=8)
45
+ img = Image.fromarray(depth_image)
46
+ return [img, gltf_path, gltf_path]
47
+ except:
48
+ print("Error reconstructing 3D model")
49
+ raise Exception("Error reconstructing 3D model")
50
+
51
+
52
+ def create_3d_obj(rgb_image, depth_image, image_path, depth=10):
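+ # Back-project the RGB-D pair into an Open3D point cloud, estimate and orient normals, run Poisson surface reconstruction, and export the cropped mesh as glTF.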
53
+ depth_o3d = o3d.geometry.Image(depth_image)
54
+ image_o3d = o3d.geometry.Image(rgb_image)
55
+ rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
56
+ image_o3d, depth_o3d, convert_rgb_to_intensity=False)
57
+ w = int(depth_image.shape[1])
58
+ h = int(depth_image.shape[0])
59
+
60
+ camera_intrinsic = o3d.camera.PinholeCameraIntrinsic()
61
+ camera_intrinsic.set_intrinsics(w, h, 500, 500, w/2, h/2)
62
+
63
+ pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
64
+ rgbd_image, camera_intrinsic)
65
+
66
+ print('normals')
67
+ pcd.normals = o3d.utility.Vector3dVector(
68
+ np.zeros((1, 3))) # invalidate existing normals
69
+ pcd.estimate_normals(
70
+ search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.01, max_nn=30))
71
+ pcd.orient_normals_towards_camera_location(
72
+ camera_location=np.array([0., 0., 1000.]))
73
+ pcd.transform([[1, 0, 0, 0],
74
+ [0, -1, 0, 0],
75
+ [0, 0, -1, 0],
76
+ [0, 0, 0, 1]])
77
+ pcd.transform([[-1, 0, 0, 0],
78
+ [0, 1, 0, 0],
79
+ [0, 0, 1, 0],
80
+ [0, 0, 0, 1]])
81
+
82
+ print('run Poisson surface reconstruction')
83
+ with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:
84
+ mesh_raw, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
85
+ pcd, depth=depth, width=0, scale=1.1, linear_fit=True)
86
+
87
+ voxel_size = max(mesh_raw.get_max_bound() - mesh_raw.get_min_bound()) / 256
88
+ print(f'voxel_size = {voxel_size:e}')
89
+ mesh = mesh_raw.simplify_vertex_clustering(
90
+ voxel_size=voxel_size,
91
+ contraction=o3d.geometry.SimplificationContraction.Average)
92
+
93
+ # vertices_to_remove = densities < np.quantile(densities, 0.001)
94
+ # mesh.remove_vertices_by_mask(vertices_to_remove)
95
+ bbox = pcd.get_axis_aligned_bounding_box()
96
+ mesh_crop = mesh.crop(bbox)
97
+ gltf_path = f'./{image_path.stem}.gltf'
98
+ o3d.io.write_triangle_mesh(
99
+ gltf_path, mesh_crop, write_triangle_uvs=True)
100
+ return gltf_path
101
+
102
+
103
+ current_directory = os.path.dirname(__file__)
104
+
105
+ title = "Demo: zero-shot depth estimation with DPT + 3D Point Cloud"
106
+ description = "This demo is a variation from the original <a href='https://huggingface.co/spaces/nielsr/dpt-depth-estimation' target='_blank'>DPT Demo</a>. It uses the DPT model to predict the depth of an image and then uses 3D Point Cloud to create a 3D object."
107
+ #examples = [["examples/" + img] for img in os.listdir("examples/")]
108
+
109
+ # result_image_path = os.path.join(current_directory, '..', 'result.png')
110
+ # image_path = Path(result_image_path)
111
+
112
+
113
+ # Load the image
114
+ # rawimage = Image.open(image_path)
115
+ # image_r = gr.Image(value=rawimage, type="pil", label="Input Image")
116
+ #image_r.change(create_visual_demo, [],[])
117
+
118
+ def create_visual_demo():
119
+ iface = gr.Interface(fn=process_image,
120
+ inputs=[gr.Image(
121
+ type="filepath", label="Input Image")],
122
+ outputs=[gr.Image(label="predicted depth", type="pil"),
123
+ gr.Model3D(label="3d mesh reconstruction", clear_color=[
124
+ 1.0, 1.0, 1.0, 1.0]),
125
+ gr.File(label="3d gLTF")],
126
+ title=title,
127
+ description=description,
128
+ #examples=examples,
129
+ live=True,
130
+ allow_flagging="never",
131
+ cache_examples=False)
132
+
133
+ #iface.launch(debug=True, enable_queue=False, share=True)
depthgltf/examples/1-jonathan-borba-CgWTqYxHEkg-unsplash.jpg ADDED
depthgltf/examples/2-ronan-furuta-cvM7AC22dSI-unsplash.jpg ADDED
depthgltf/examples/3-artem-beliaikin-vyxOD0NuJbs-unsplash.jpg ADDED
depthgltf/examples/alisa-anton-PXgXLgDPv6w-unsplash.jpg ADDED
depthgltf/examples/joel-muniz-KodMXENNaas-unsplash.jpg ADDED
depthgltf/examples/opollo-photography-nxy9wFUiksg-unsplash.jpg ADDED
depthgltf/examples/zeynep-sumer-HE2nWVI62BY-unsplash.jpg ADDED
depthgltf/packages.txt ADDED
@@ -0,0 +1 @@
1
+ libgl1-mesa-glx
depthgltf/requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ torch
2
+ git+https://github.com/nielsrogge/transformers.git@add_dpt_redesign#egg=transformers
3
+ numpy
4
+ Pillow
5
+ gradio
6
+ jinja2
7
+ open3d
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch==2.0.1
2
+ torchvision==0.15.2
3
+ xformers
4
+ bitsandbytes
5
+ transformers
6
+ accelerate
7
+ git+https://github.com/huggingface/diffusers.git
8
+ scipy
9
+ huggingface_hub
10
+ datasets
sdxl/.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
sdxl/README.md ADDED
@@ -0,0 +1,12 @@
1
+ ---
2
+ title: SD-XL + Control LoRas
3
+ emoji: 🦀
4
+ colorFrom: green
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 3.44.4
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
sdxl/app_inference.py ADDED
@@ -0,0 +1,363 @@
1
+ import gradio as gr
2
+ from huggingface_hub import login, HfFileSystem, HfApi, ModelCard
3
+ import os
4
+ import spaces
5
+ import random
6
+ import torch
7
+
8
+ is_shared_ui = False
9
+
10
+ hf_token = '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'
11
+ login(token=hf_token)
12
+
13
+ fs = HfFileSystem(token=hf_token)
14
+ api = HfApi()
15
+
16
+ device="cuda" if torch.cuda.is_available() else "cpu"
17
+
18
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
19
+ from diffusers.utils import load_image
20
+ from PIL import Image
21
+ import torch
22
+ import numpy as np
23
+ import cv2
24
+
25
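+ # Loaded once at module import and shared across requests: the fp16-fix SDXL VAE and the canny SDXL ControlNet.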
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
26
+
27
+ controlnet = ControlNetModel.from_pretrained(
28
+ "diffusers/controlnet-canny-sdxl-1.0",
29
+ torch_dtype=torch.float16
30
+ )
31
+
32
+ def check_use_custom_or_no(value):
33
+ if value is True:
34
+ return gr.update(visible=True)
35
+ else:
36
+ return gr.update(visible=False)
37
+
38
+ def get_files(file_paths):
39
+ last_files = {} # Dictionary to store the last file for each path
40
+
41
+ for file_path in file_paths:
42
+ # Split the file path into directory and file components
43
+ directory, file_name = file_path.rsplit('/', 1)
44
+
45
+ # Update the last file for the current path
46
+ last_files[directory] = file_name
47
+
48
+ # Extract the last files from the dictionary
49
+ result = list(last_files.values())
50
+
51
+ return result
52
+
53
+ def load_model(model_name):
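+ # Read the model card to recover the instance_prompt (trigger word) and list the .safetensors weight files available in the repo.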
54
+
55
+ if model_name == "":
56
+ gr.Warning("If you want to use a private model, you need to duplicate this space on your personal account.")
57
+ raise gr.Error("You forgot to define Model ID.")
58
+
59
+ # Get instance_prompt a.k.a trigger word
60
+ card = ModelCard.load(model_name)
61
+ repo_data = card.data.to_dict()
62
+ instance_prompt = repo_data.get("instance_prompt")
63
+
64
+ if instance_prompt is not None:
65
+ print(f"Trigger word: {instance_prompt}")
66
+ else:
67
+ instance_prompt = "no trigger word needed"
68
+ print(f"Trigger word: no trigger word needed")
69
+
70
+ # List all ".safetensors" files in repo
71
+ sfts_available_files = fs.glob(f"{model_name}/*safetensors")
72
+ sfts_available_files = get_files(sfts_available_files)
73
+
74
+ if sfts_available_files == []:
75
+ sfts_available_files = ["NO SAFETENSORS FILE"]
76
+
77
+ print(f"Safetensors available: {sfts_available_files}")
78
+
79
+ return model_name, "Model Ready", gr.update(choices=sfts_available_files, value=sfts_available_files[0], visible=True), gr.update(value=instance_prompt, visible=True)
80
+
81
+ def custom_model_changed(model_name, previous_model):
82
+ if model_name == "" and previous_model == "" :
83
+ status_message = ""
84
+ elif model_name != previous_model:
85
+ status_message = "model changed, please reload before any new run"
86
+ else:
87
+ status_message = "model ready"
88
+ return status_message
89
+
90
+ def resize_image(input_path, output_path, target_height):
91
+ # Open the input image
92
+ img = Image.open(input_path)
93
+
94
+ # Calculate the aspect ratio of the original image
95
+ original_width, original_height = img.size
96
+ original_aspect_ratio = original_width / original_height
97
+
98
+ # Calculate the new width while maintaining the aspect ratio and the target height
99
+ new_width = int(target_height * original_aspect_ratio)
100
+
101
+ # Resize the image while maintaining the aspect ratio and fixing the height
102
+ img = img.resize((new_width, target_height), Image.LANCZOS)
103
+
104
+ # Save the resized image
105
+ img.save(output_path)
106
+
107
+ return output_path
108
+
109
+ @spaces.GPU
110
+ def infer(use_custom_model, model_name, weight_name, custom_lora_weight, image_in, prompt, negative_prompt, preprocessor, controlnet_conditioning_scale, guidance_scale, inf_steps, seed, progress=gr.Progress(track_tqdm=True)):
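+ # Build an SDXL + canny ControlNet pipeline, optionally load the user's LoRA weights, and generate an image conditioned on the Canny edges of the uploaded picture.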
111
+
112
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
113
+ "stabilityai/stable-diffusion-xl-base-1.0",
114
+ controlnet=controlnet,
115
+ vae=vae,
116
+ torch_dtype=torch.float16,
117
+ variant="fp16",
118
+ use_safetensors=True
119
+ )
120
+
121
+ pipe.to(device)
122
+
123
+ prompt = prompt
124
+ negative_prompt = negative_prompt
125
+
126
+ if seed < 0 :
127
+ seed = random.randint(0, 423538377342)
128
+
129
+ generator = torch.Generator(device=device).manual_seed(seed)
130
+
131
+ if image_in is None:
132
+ raise gr.Error("You forgot to upload a source image.")
133
+
134
+ image_in = resize_image(image_in, "resized_input.jpg", 1024)
135
+
136
+ if preprocessor == "canny":
137
+
138
+ image = load_image(image_in)
139
+
140
+ image = np.array(image)
141
+ image = cv2.Canny(image, 100, 200)
142
+ image = image[:, :, None]
143
+ image = np.concatenate([image, image, image], axis=2)
144
+ image = Image.fromarray(image)
145
+
146
+ if use_custom_model:
147
+
148
+ if model_name == "":
149
+ raise gr.Error("you forgot to set a custom model name.")
150
+
151
+ custom_model = model_name
152
+
153
+ # This is where you load your trained weights
154
+ if weight_name == "NO SAFETENSORS FILE":
155
+ pipe.load_lora_weights(
156
+ custom_model,
157
+ low_cpu_mem_usage = True,
158
+ use_auth_token = True
159
+ )
160
+
161
+ else:
162
+ pipe.load_lora_weights(
163
+ custom_model,
164
+ weight_name = weight_name,
165
+ low_cpu_mem_usage = True,
166
+ use_auth_token = True
167
+ )
168
+
169
+ lora_scale=custom_lora_weight
170
+
171
+ images = pipe(
172
+ prompt,
173
+ negative_prompt=negative_prompt,
174
+ image=image,
175
+ controlnet_conditioning_scale=float(controlnet_conditioning_scale),
176
+ guidance_scale = float(guidance_scale),
177
+ num_inference_steps=inf_steps,
178
+ generator=generator,
179
+ cross_attention_kwargs={"scale": lora_scale}
180
+ ).images
181
+ else:
182
+ images = pipe(
183
+ prompt,
184
+ negative_prompt=negative_prompt,
185
+ image=image,
186
+ controlnet_conditioning_scale=float(controlnet_conditioning_scale),
187
+ guidance_scale = float(guidance_scale),
188
+ num_inference_steps=inf_steps,
189
+ generator=generator,
190
+ ).images
191
+
192
+ images[0].save(f"result.png")
193
+
194
+ return f"result.png", seed
195
+
196
+ css="""
197
+ #col-container{
198
+ margin: 0 auto;
199
+ max-width: 720px;
200
+ text-align: left;
201
+ }
202
+ div#warning-duplicate {
203
+ background-color: #ebf5ff;
204
+ padding: 0 10px 5px;
205
+ margin: 20px 0;
206
+ }
207
+ div#warning-duplicate > .gr-prose > h2, div#warning-duplicate > .gr-prose > p {
208
+ color: #0f4592!important;
209
+ }
210
+ div#warning-duplicate strong {
211
+ color: #0f4592;
212
+ }
213
+ p.actions {
214
+ display: flex;
215
+ align-items: center;
216
+ margin: 20px 0;
217
+ }
218
+ div#warning-duplicate .actions a {
219
+ display: inline-block;
220
+ margin-right: 10px;
221
+ }
222
+ button#load_model_btn{
223
+ height: 46px;
224
+ }
225
+ #status_info{
226
+ font-size: 0.9em;
227
+ }
228
+ """
229
+ def create_inference_demo() -> gr.Blocks:
230
+
231
+ with gr.Blocks(css=css) as demo:
232
+ with gr.Column(elem_id="col-container"):
233
+ if is_shared_ui:
234
+ top_description = gr.HTML(f'''
235
+ <div class="gr-prose">
236
+ <h2><svg xmlns="http://www.w3.org/2000/svg" width="18px" height="18px" style="margin-right: 0px;display: inline-block;"fill="none"><path fill="#fff" d="M7 13.2a6.3 6.3 0 0 0 4.4-10.7A6.3 6.3 0 0 0 .6 6.9 6.3 6.3 0 0 0 7 13.2Z"/><path fill="#fff" fill-rule="evenodd" d="M7 0a6.9 6.9 0 0 1 4.8 11.8A6.9 6.9 0 0 1 0 7 6.9 6.9 0 0 1 7 0Zm0 0v.7V0ZM0 7h.6H0Zm7 6.8v-.6.6ZM13.7 7h-.6.6ZM9.1 1.7c-.7-.3-1.4-.4-2.2-.4a5.6 5.6 0 0 0-4 1.6 5.6 5.6 0 0 0-1.6 4 5.6 5.6 0 0 0 1.6 4 5.6 5.6 0 0 0 4 1.7 5.6 5.6 0 0 0 4-1.7 5.6 5.6 0 0 0 1.7-4 5.6 5.6 0 0 0-1.7-4c-.5-.5-1.1-.9-1.8-1.2Z" clip-rule="evenodd"/><path fill="#000" fill-rule="evenodd" d="M7 2.9a.8.8 0 1 1 0 1.5A.8.8 0 0 1 7 3ZM5.8 5.7c0-.4.3-.6.6-.6h.7c.3 0 .6.2.6.6v3.7h.5a.6.6 0 0 1 0 1.3H6a.6.6 0 0 1 0-1.3h.4v-3a.6.6 0 0 1-.6-.7Z" clip-rule="evenodd"/></svg>
237
+ Note: you might want to use a <strong>private</strong> custom LoRa model</h2>
238
+ <p class="main-message">
239
+ To do so, <strong>duplicate the Space</strong> and run it on your own profile using <strong>your own access token</strong> and eventually a GPU (T4-small or A10G-small) for faster inference without waiting in the queue.<br />
240
+ </p>
241
+ <p class="actions">
242
+ <a href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true">
243
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg-dark.svg" alt="Duplicate this Space" />
244
+ </a>
245
+ to start using private models and skip the queue
246
+ </p>
247
+ </div>
248
+ ''', elem_id="warning-duplicate")
249
+ gr.HTML("""
250
+ <h2 style="text-align: center;">SD-XL Control LoRas</h2>
251
+ <p style="text-align: center;">Use StableDiffusion XL with <a href="https://huggingface.co/collections/diffusers/sdxl-controlnets-64f9c35846f3f06f5abe351f">Diffusers' SDXL ControlNets</a></p>
252
+
253
+ """)
254
+
255
+ use_custom_model = gr.Checkbox(label="Use a custom pre-trained LoRa model ? (optional)", value=False, info="To use a private model, you'll need to duplicate the space with your own access token.")
256
+
257
+ with gr.Box(visible=False) as custom_model_box:
258
+ with gr.Row():
259
+ with gr.Column():
260
+ if not is_shared_ui:
261
+ your_username = api.whoami()["name"]
262
+ my_models = api.list_models(author=your_username, filter=["diffusers", "stable-diffusion-xl", 'lora'])
263
+ model_names = [item.modelId for item in my_models]
264
+
265
+ if not is_shared_ui:
266
+ custom_model = gr.Dropdown(
267
+ label = "Your custom model ID",
268
+ info="You can pick one of your private models",
269
+ choices = model_names,
270
+ allow_custom_value = True
271
+ #placeholder = "username/model_id"
272
+ )
273
+ else:
274
+ custom_model = gr.Textbox(
275
+ label="Your custom model ID",
276
+ placeholder="your_username/your_trained_model_name",
277
+ info="Make sure your model is set to PUBLIC"
278
+ )
279
+
280
+ weight_name = gr.Dropdown(
281
+ label="Safetensors file",
282
+ #value="pytorch_lora_weights.safetensors",
283
+ info="specify which one if model has several .safetensors files",
284
+ allow_custom_value=True,
285
+ visible = False
286
+ )
287
+ with gr.Column():
288
+ with gr.Group():
289
+ load_model_btn = gr.Button("Load my model", elem_id="load_model_btn")
290
+ previous_model = gr.Textbox(
291
+ visible = False
292
+ )
293
+ model_status = gr.Textbox(
294
+ label = "model status",
295
+ show_label = False,
296
+ elem_id = "status_info"
297
+ )
298
+ trigger_word = gr.Textbox(label="Trigger word", interactive=False, visible=False)
299
+
300
+ image_in = gr.Image(source="upload", type="filepath")
301
+
302
+ with gr.Row():
303
+
304
+ with gr.Column():
305
+ with gr.Group():
306
+ prompt = gr.Textbox(label="Prompt")
307
+ negative_prompt = gr.Textbox(label="Negative prompt", value="extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured")
308
+ with gr.Group():
309
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=7.5)
310
+ inf_steps = gr.Slider(label="Inference Steps", minimum="25", maximum="50", step=1, value=25)
311
+ custom_lora_weight = gr.Slider(label="Custom model weights", minimum=0.1, maximum=0.9, step=0.1, value=0.9)
312
+
313
+ with gr.Column():
314
+ with gr.Group():
315
+ preprocessor = gr.Dropdown(label="Preprocessor", choices=["canny"], value="canny", interactive=False, info="For the moment, only canny is available")
316
+ controlnet_conditioning_scale = gr.Slider(label="Controlnet conditioning Scale", minimum=0.1, maximum=0.9, step=0.01, value=0.5)
317
+ with gr.Group():
318
+ seed = gr.Slider(
319
+ label="Seed",
320
+ info = "-1 denotes a random seed",
321
+ minimum=-1,
322
+ maximum=423538377342,
323
+ step=1,
324
+ value=-1
325
+ )
326
+ last_used_seed = gr.Number(
327
+ label = "Last used seed",
328
+ info = "the seed used in the last generation",
329
+ )
330
+
331
+
332
+ submit_btn = gr.Button("Submit")
333
+
334
+ result = gr.Image(label="Result")
335
+
336
+ use_custom_model.change(
337
+ fn = check_use_custom_or_no,
338
+ inputs =[use_custom_model],
339
+ outputs = [custom_model_box],
340
+ queue = False
341
+ )
342
+ custom_model.blur(
343
+ fn=custom_model_changed,
344
+ inputs = [custom_model, previous_model],
345
+ outputs = [model_status],
346
+ queue = False
347
+ )
348
+ load_model_btn.click(
349
+ fn = load_model,
350
+ inputs=[custom_model],
351
+ outputs = [previous_model, model_status, weight_name, trigger_word],
352
+ queue = False
353
+ )
354
+ submit_btn.click(
355
+ fn = infer,
356
+ inputs = [use_custom_model, custom_model, weight_name, custom_lora_weight, image_in, prompt, negative_prompt, preprocessor, controlnet_conditioning_scale, guidance_scale, inf_steps, seed],
357
+ outputs = [result, last_used_seed]
358
+ )
359
+
360
+ return demo
361
+
362
+
363
+ #demo.queue(max_size=12).launch(share=True)
sdxl/requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ torch==2.0.1
2
+ torchvision==0.15.2
3
+ invisible_watermark
4
+ accelerate
5
+ transformers
6
+ safetensors
7
+ opencv-python
8
+ git+https://github.com/huggingface/diffusers.git
train_dreambooth_lora_sdxl.py ADDED
@@ -0,0 +1,1508 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import gc
18
+ import hashlib
19
+ import itertools
20
+ import logging
21
+ import math
22
+ import os
23
+ import shutil
24
+ import warnings
25
+ from pathlib import Path
26
+ from typing import Dict
27
+
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+ import transformers
33
+ from accelerate import Accelerator
34
+ from accelerate.logging import get_logger
35
+ from accelerate.utils import ProjectConfiguration, set_seed
36
+ from huggingface_hub import create_repo, upload_folder
37
+ from packaging import version
38
+ from PIL import Image
39
+ from PIL.ImageOps import exif_transpose
40
+ from torch.utils.data import Dataset
41
+ from torchvision import transforms
42
+ from tqdm.auto import tqdm
43
+ from transformers import AutoTokenizer, PretrainedConfig
44
+
45
+ import diffusers
46
+ from diffusers import (
47
+ AutoencoderKL,
48
+ DDPMScheduler,
49
+ DPMSolverMultistepScheduler,
50
+ StableDiffusionXLPipeline,
51
+ UNet2DConditionModel,
52
+ )
53
+ from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
54
+ from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
55
+ from diffusers.optimization import get_scheduler
56
+ from diffusers.utils import check_min_version, is_wandb_available
57
+ from diffusers.utils.import_utils import is_xformers_available
58
+
59
+
60
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
61
+ check_min_version("0.22.0.dev0")
62
+
63
+ logger = get_logger(__name__)
64
+
65
+ def save_tempo_model_card(
66
+ repo_id: str, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None, last_checkpoint=str
67
+ ):
68
+
69
+ yaml = f"""
70
+ ---
71
+ base_model: {base_model}
72
+ instance_prompt: {prompt}
73
+ tags:
74
+ - stable-diffusion-xl
75
+ - stable-diffusion-xl-diffusers
76
+ - text-to-image
77
+ - diffusers
78
+ - lora
79
+ inference: false
80
+ datasets:
81
+ - {dataset_id}
82
+ ---
83
+ """
84
+ model_card = f"""
85
+ # LoRA DreamBooth - {repo_id}
86
+ ## MODEL IS CURRENTLY TRAINING ...
87
+ Last checkpoint saved: {last_checkpoint}
88
+ These are LoRA adaptation weights for {base_model} trained on @fffiloni's SD-XL trainer.
89
+ The weights were trained on the concept prompt:
90
+ ```
91
+ {prompt}
92
+ ```
93
+ Use this keyword to trigger your custom model in your prompts.
94
+ LoRA for the text encoder was enabled: {train_text_encoder}.
95
+ Special VAE used for training: {vae_path}.
96
+ ## Usage
97
+ Make sure to upgrade diffusers to >= 0.19.0:
98
+ ```
99
+ pip install diffusers --upgrade
100
+ ```
101
+ In addition make sure to install transformers, safetensors, accelerate as well as the invisible watermark:
102
+ ```
103
+ pip install invisible_watermark transformers accelerate safetensors
104
+ ```
105
+ To just use the base model, you can run:
106
+ ```python
107
+ import torch
108
+ from diffusers import DiffusionPipeline, AutoencoderKL
109
+ device = "cuda" if torch.cuda.is_available() else "cpu"
110
+ vae = AutoencoderKL.from_pretrained('{vae_path}', torch_dtype=torch.float16)
111
+ pipe = DiffusionPipeline.from_pretrained(
112
+ "stabilityai/stable-diffusion-xl-base-1.0",
113
+ vae=vae, torch_dtype=torch.float16, variant="fp16",
114
+ use_safetensors=True
115
+ )
116
+ pipe.to(device)
117
+ # This is where you load your trained weights
118
+ specific_safetensors = "pytorch_lora_weights.safetensors"
119
+ lora_scale = 0.9
120
+ pipe.load_lora_weights(
121
+ '{repo_id}',
122
+ weight_name = specific_safetensors,
123
+ # use_auth_token = True
124
+ )
125
+ prompt = "A majestic {prompt} jumping from a big stone at night"
126
+ image = pipe(
127
+ prompt=prompt,
128
+ num_inference_steps=50,
129
+ cross_attention_kwargs={{"scale": lora_scale}}
130
+ ).images[0]
131
+ ```
132
+ """
133
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
134
+ f.write(yaml + model_card)
135
+
136
+ def save_model_card(
137
+ repo_id: str, images=None, dataset_id=str, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
138
+ ):
139
+ img_str = ""
140
+ for i, image in enumerate(images):
141
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
142
+ img_str += f"![img_{i}](./image_{i}.png)\n"
143
+
144
+ yaml = f"""
145
+ ---
146
+ base_model: {base_model}
147
+ instance_prompt: {prompt}
148
+ tags:
149
+ - stable-diffusion-xl
150
+ - stable-diffusion-xl-diffusers
151
+ - text-to-image
152
+ - diffusers
153
+ - lora
154
+ inference: false
155
+ datasets:
156
+ - {dataset_id}
157
+ ---
158
+ """
159
+ model_card = f"""
160
+ # LoRA DreamBooth - {repo_id}
161
+ These are LoRA adaptation weights for {base_model} trained on @fffiloni's SD-XL trainer.
162
+ The weights were trained on the concept prompt:
163
+ ```
164
+ {prompt}
165
+ ```
166
+ Use this keyword to trigger your custom model in your prompts.
167
+ LoRA for the text encoder was enabled: {train_text_encoder}.
168
+ Special VAE used for training: {vae_path}.
169
+ ## Usage
170
+ Make sure to upgrade diffusers to >= 0.19.0:
171
+ ```
172
+ pip install diffusers --upgrade
173
+ ```
174
+ In addition make sure to install transformers, safetensors, accelerate as well as the invisible watermark:
175
+ ```
176
+ pip install invisible_watermark transformers accelerate safetensors
177
+ ```
178
+ To just use the base model, you can run:
179
+ ```python
180
+ import torch
181
+ from diffusers import DiffusionPipeline, AutoencoderKL
182
+ device = "cuda" if torch.cuda.is_available() else "cpu"
183
+ vae = AutoencoderKL.from_pretrained('{vae_path}', torch_dtype=torch.float16)
184
+ pipe = DiffusionPipeline.from_pretrained(
185
+ "stabilityai/stable-diffusion-xl-base-1.0",
186
+ vae=vae, torch_dtype=torch.float16, variant="fp16",
187
+ use_safetensors=True
188
+ )
189
+ pipe.to(device)
190
+ # This is where you load your trained weights
191
+ specific_safetensors = "pytorch_lora_weights.safetensors"
192
+ lora_scale = 0.9
193
+ pipe.load_lora_weights(
194
+ '{repo_id}',
195
+ weight_name = specific_safetensors,
196
+ # use_auth_token = True
197
+ )
198
+ prompt = "A majestic {prompt} jumping from a big stone at night"
199
+ image = pipe(
200
+ prompt=prompt,
201
+ num_inference_steps=50,
202
+ cross_attention_kwargs={{"scale": lora_scale}}
203
+ ).images[0]
204
+ ```
205
+ """
206
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
207
+ f.write(yaml + model_card)
208
+
209
+
210
+ def import_model_class_from_model_name_or_path(
211
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
212
+ ):
213
+ text_encoder_config = PretrainedConfig.from_pretrained(
214
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
215
+ )
216
+ model_class = text_encoder_config.architectures[0]
217
+
218
+ if model_class == "CLIPTextModel":
219
+ from transformers import CLIPTextModel
220
+
221
+ return CLIPTextModel
222
+ elif model_class == "CLIPTextModelWithProjection":
223
+ from transformers import CLIPTextModelWithProjection
224
+
225
+ return CLIPTextModelWithProjection
226
+ else:
227
+ raise ValueError(f"{model_class} is not supported.")
228
+
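+ # Illustrative sketch (not part of the training flow): for the SDXL base checkpoint the two
+ # text-encoder subfolders resolve to different classes, e.g.
+ #   import_model_class_from_model_name_or_path("stabilityai/stable-diffusion-xl-base-1.0", None)
+ #       -> CLIPTextModel                 (subfolder="text_encoder")
+ #   import_model_class_from_model_name_or_path("stabilityai/stable-diffusion-xl-base-1.0", None, "text_encoder_2")
+ #       -> CLIPTextModelWithProjection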
229
+
230
+ def parse_args(input_args=None):
231
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
232
+ parser.add_argument(
233
+ "--pretrained_model_name_or_path",
234
+ type=str,
235
+ default=None,
236
+ required=True,
237
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
238
+ )
239
+ parser.add_argument(
240
+ "--pretrained_vae_model_name_or_path",
241
+ type=str,
242
+ default=None,
243
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
244
+ )
245
+ parser.add_argument(
246
+ "--revision",
247
+ type=str,
248
+ default=None,
249
+ required=False,
250
+ help="Revision of pretrained model identifier from huggingface.co/models.",
251
+ )
252
+ parser.add_argument(
253
+ "--dataset_id",
254
+ type=str,
255
+ default=None,
256
+ required=True,
257
+ help="The dataset ID you want to train images from",
258
+ )
259
+ parser.add_argument(
260
+ "--instance_data_dir",
261
+ type=str,
262
+ default=None,
263
+ required=True,
264
+ help="A folder containing the training data of instance images.",
265
+ )
266
+ parser.add_argument(
267
+ "--class_data_dir",
268
+ type=str,
269
+ default=None,
270
+ required=False,
271
+ help="A folder containing the training data of class images.",
272
+ )
273
+ parser.add_argument(
274
+ "--instance_prompt",
275
+ type=str,
276
+ default=None,
277
+ required=True,
278
+ help="The prompt with identifier specifying the instance",
279
+ )
280
+ parser.add_argument(
281
+ "--class_prompt",
282
+ type=str,
283
+ default=None,
284
+ help="The prompt to specify images in the same class as provided instance images.",
285
+ )
286
+ parser.add_argument(
287
+ "--validation_prompt",
288
+ type=str,
289
+ default=None,
290
+ help="A prompt that is used during validation to verify that the model is learning.",
291
+ )
292
+ parser.add_argument(
293
+ "--num_validation_images",
294
+ type=int,
295
+ default=4,
296
+ help="Number of images that should be generated during validation with `validation_prompt`.",
297
+ )
298
+ parser.add_argument(
299
+ "--validation_epochs",
300
+ type=int,
301
+ default=50,
302
+ help=(
303
+ "Run DreamBooth validation every X epochs. Validation consists of generating `args.num_validation_images`"
304
+ " images with the prompt `args.validation_prompt`."
305
+ ),
306
+ )
307
+ parser.add_argument(
308
+ "--with_prior_preservation",
309
+ default=False,
310
+ action="store_true",
311
+ help="Flag to add prior preservation loss.",
312
+ )
313
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
314
+ parser.add_argument(
315
+ "--num_class_images",
316
+ type=int,
317
+ default=100,
318
+ help=(
319
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
320
+ " class_data_dir, additional images will be sampled with class_prompt."
321
+ ),
322
+ )
323
+ parser.add_argument(
324
+ "--output_dir",
325
+ type=str,
326
+ default="lora-dreambooth-model",
327
+ help="The output directory where the model predictions and checkpoints will be written.",
328
+ )
329
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
330
+ parser.add_argument(
331
+ "--resolution",
332
+ type=int,
333
+ default=1024,
334
+ help=(
335
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
336
+ " resolution"
337
+ ),
338
+ )
339
+ parser.add_argument(
340
+ "--crops_coords_top_left_h",
341
+ type=int,
342
+ default=0,
343
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
344
+ )
345
+ parser.add_argument(
346
+ "--crops_coords_top_left_w",
347
+ type=int,
348
+ default=0,
349
+ help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
350
+ )
351
+ parser.add_argument(
352
+ "--center_crop",
353
+ default=False,
354
+ action="store_true",
355
+ help=(
356
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
357
+ " cropped. The images will be resized to the resolution first before cropping."
358
+ ),
359
+ )
360
+ parser.add_argument(
361
+ "--train_text_encoder",
362
+ action="store_true",
363
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
364
+ )
365
+ parser.add_argument(
366
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
367
+ )
368
+ parser.add_argument(
369
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
370
+ )
371
+ parser.add_argument("--num_train_epochs", type=int, default=1)
372
+ parser.add_argument(
373
+ "--max_train_steps",
374
+ type=int,
375
+ default=None,
376
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
377
+ )
378
+ parser.add_argument(
379
+ "--checkpointing_steps",
380
+ type=int,
381
+ default=500,
382
+ help=(
383
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
384
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
385
+ " training using `--resume_from_checkpoint`."
386
+ ),
387
+ )
388
+ parser.add_argument(
389
+ "--checkpoints_total_limit",
390
+ type=int,
391
+ default=None,
392
+ help=("Max number of checkpoints to store."),
393
+ )
394
+ parser.add_argument(
395
+ "--resume_from_checkpoint",
396
+ type=str,
397
+ default=None,
398
+ help=(
399
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
400
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
401
+ ),
402
+ )
403
+ parser.add_argument(
404
+ "--gradient_accumulation_steps",
405
+ type=int,
406
+ default=1,
407
+ help="Number of update steps to accumulate before performing a backward/update pass.",
408
+ )
409
+ parser.add_argument(
410
+ "--gradient_checkpointing",
411
+ action="store_true",
412
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
413
+ )
414
+ parser.add_argument(
415
+ "--learning_rate",
416
+ type=float,
417
+ default=5e-4,
418
+ help="Initial learning rate (after the potential warmup period) to use.",
419
+ )
420
+ parser.add_argument(
421
+ "--scale_lr",
422
+ action="store_true",
423
+ default=False,
424
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
425
+ )
426
+ parser.add_argument(
427
+ "--lr_scheduler",
428
+ type=str,
429
+ default="constant",
430
+ help=(
431
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
432
+ ' "constant", "constant_with_warmup"]'
433
+ ),
434
+ )
435
+ parser.add_argument(
436
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
437
+ )
438
+ parser.add_argument(
439
+ "--lr_num_cycles",
440
+ type=int,
441
+ default=1,
442
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
443
+ )
444
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
445
+ parser.add_argument(
446
+ "--dataloader_num_workers",
447
+ type=int,
448
+ default=0,
449
+ help=(
450
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
451
+ ),
452
+ )
453
+ parser.add_argument(
454
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
455
+ )
456
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
457
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
458
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
459
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
460
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
461
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
462
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
463
+ parser.add_argument(
464
+ "--hub_model_id",
465
+ type=str,
466
+ default=None,
467
+ help="The name of the repository to keep in sync with the local `output_dir`.",
468
+ )
469
+ parser.add_argument(
470
+ "--logging_dir",
471
+ type=str,
472
+ default="logs",
473
+ help=(
474
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
475
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
476
+ ),
477
+ )
478
+ parser.add_argument(
479
+ "--allow_tf32",
480
+ action="store_true",
481
+ help=(
482
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
483
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
484
+ ),
485
+ )
486
+ parser.add_argument(
487
+ "--report_to",
488
+ type=str,
489
+ default="tensorboard",
490
+ help=(
491
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
492
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
493
+ ),
494
+ )
495
+ parser.add_argument(
496
+ "--mixed_precision",
497
+ type=str,
498
+ default=None,
499
+ choices=["no", "fp16", "bf16"],
500
+ help=(
501
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
502
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
503
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
504
+ ),
505
+ )
506
+ parser.add_argument(
507
+ "--prior_generation_precision",
508
+ type=str,
509
+ default=None,
510
+ choices=["no", "fp32", "fp16", "bf16"],
511
+ help=(
512
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
513
+ " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
514
+ ),
515
+ )
516
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
517
+ parser.add_argument(
518
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
519
+ )
520
+ parser.add_argument(
521
+ "--rank",
522
+ type=int,
523
+ default=4,
524
+ help=("The dimension of the LoRA update matrices."),
525
+ )
526
+
527
+ if input_args is not None:
528
+ args = parser.parse_args(input_args)
529
+ else:
530
+ args = parser.parse_args()
531
+
532
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
533
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
534
+ args.local_rank = env_local_rank
535
+
536
+ if args.with_prior_preservation:
537
+ if args.class_data_dir is None:
538
+ raise ValueError("You must specify a data directory for class images.")
539
+ if args.class_prompt is None:
540
+ raise ValueError("You must specify prompt for class images.")
541
+ else:
542
+ # logger is not available yet
543
+ if args.class_data_dir is not None:
544
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
545
+ if args.class_prompt is not None:
546
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
547
+
548
+ return args
549
+
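+ # Illustrative invocation sketch (the script name and all flag values below are placeholders,
+ # not the values used by this Space):
+ #   accelerate launch <this_training_script>.py \
+ #     --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
+ #     --dataset_id="<user>/<dataset>" --instance_data_dir="./instance_images" \
+ #     --instance_prompt="a photo of sks shop" --resolution=1024 \
+ #     --train_batch_size=1 --gradient_accumulation_steps=4 --learning_rate=1e-4 \
+ #     --max_train_steps=500 --push_to_hub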
550
+
551
+ class DreamBoothDataset(Dataset):
552
+ """
553
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
554
+ It pre-processes the images.
555
+ """
556
+
557
+ def __init__(
558
+ self,
559
+ instance_data_root,
560
+ class_data_root=None,
561
+ class_num=None,
562
+ size=1024,
563
+ center_crop=False,
564
+ ):
565
+ self.size = size
566
+ self.center_crop = center_crop
567
+
568
+ self.instance_data_root = Path(instance_data_root)
569
+ if not self.instance_data_root.exists():
570
+ raise ValueError("Instance images root doesn't exist.")
571
+
572
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
573
+ self.num_instance_images = len(self.instance_images_path)
574
+ self._length = self.num_instance_images
575
+
576
+ if class_data_root is not None:
577
+ self.class_data_root = Path(class_data_root)
578
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
579
+ self.class_images_path = list(self.class_data_root.iterdir())
580
+ if class_num is not None:
581
+ self.num_class_images = min(len(self.class_images_path), class_num)
582
+ else:
583
+ self.num_class_images = len(self.class_images_path)
584
+ self._length = max(self.num_class_images, self.num_instance_images)
585
+ else:
586
+ self.class_data_root = None
587
+
588
+ self.image_transforms = transforms.Compose(
589
+ [
590
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
591
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
592
+ transforms.ToTensor(),
593
+ transforms.Normalize([0.5], [0.5]),
594
+ ]
595
+ )
596
+
597
+ def __len__(self):
598
+ return self._length
599
+
600
+ def __getitem__(self, index):
601
+ example = {}
602
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
603
+ instance_image = exif_transpose(instance_image)
604
+
605
+ if not instance_image.mode == "RGB":
606
+ instance_image = instance_image.convert("RGB")
607
+ example["instance_images"] = self.image_transforms(instance_image)
608
+
609
+ if self.class_data_root:
610
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
611
+ class_image = exif_transpose(class_image)
612
+
613
+ if not class_image.mode == "RGB":
614
+ class_image = class_image.convert("RGB")
615
+ example["class_images"] = self.image_transforms(class_image)
616
+
617
+ return example
618
+
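+ # Minimal usage sketch of the dataset above (paths are placeholders):
+ #   dataset = DreamBoothDataset(instance_data_root="./instance_images", size=1024)
+ #   example = dataset[0]
+ #   example["instance_images"].shape  # -> torch.Size([3, 1024, 1024]), values normalized to [-1, 1]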
619
+
620
+ def collate_fn(examples, with_prior_preservation=False):
621
+ pixel_values = [example["instance_images"] for example in examples]
622
+
623
+ # Concat class and instance examples for prior preservation.
624
+ # We do this to avoid doing two forward passes.
625
+ if with_prior_preservation:
626
+ pixel_values += [example["class_images"] for example in examples]
627
+
628
+ pixel_values = torch.stack(pixel_values)
629
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
630
+
631
+ batch = {"pixel_values": pixel_values}
632
+ return batch
633
+
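+ # Note: with prior preservation the batch stacks instance images first and class images second,
+ # which is what lets the training loop later split predictions with torch.chunk(model_pred, 2, dim=0).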
634
+
635
+ class PromptDataset(Dataset):
636
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
637
+
638
+ def __init__(self, prompt, num_samples):
639
+ self.prompt = prompt
640
+ self.num_samples = num_samples
641
+
642
+ def __len__(self):
643
+ return self.num_samples
644
+
645
+ def __getitem__(self, index):
646
+ example = {}
647
+ example["prompt"] = self.prompt
648
+ example["index"] = index
649
+ return example
650
+
651
+
652
+ def tokenize_prompt(tokenizer, prompt):
653
+ text_inputs = tokenizer(
654
+ prompt,
655
+ padding="max_length",
656
+ max_length=tokenizer.model_max_length,
657
+ truncation=True,
658
+ return_tensors="pt",
659
+ )
660
+ text_input_ids = text_inputs.input_ids
661
+ return text_input_ids
662
+
663
+
664
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
665
+ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
666
+ prompt_embeds_list = []
667
+
668
+ for i, text_encoder in enumerate(text_encoders):
669
+ if tokenizers is not None:
670
+ tokenizer = tokenizers[i]
671
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
672
+ else:
673
+ assert text_input_ids_list is not None
674
+ text_input_ids = text_input_ids_list[i]
675
+
676
+ prompt_embeds = text_encoder(
677
+ text_input_ids.to(text_encoder.device),
678
+ output_hidden_states=True,
679
+ )
680
+
681
+ # We are only interested in the pooled output of the final text encoder
682
+ pooled_prompt_embeds = prompt_embeds[0]
683
+ prompt_embeds = prompt_embeds.hidden_states[-2]
684
+ bs_embed, seq_len, _ = prompt_embeds.shape
685
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
686
+ prompt_embeds_list.append(prompt_embeds)
687
+
688
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
689
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
690
+ return prompt_embeds, pooled_prompt_embeds
691
+
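+ # Note: encode_prompt() concatenates the penultimate hidden states of both SDXL text encoders
+ # along the feature dimension, while only the pooled output of the second encoder is kept as
+ # the additional "text_embeds" conditioning for the UNet.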
692
+
693
+ def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
694
+ """
695
+ Returns:
696
+ a state dict containing just the attention processor parameters.
697
+ """
698
+ attn_processors = unet.attn_processors
699
+
700
+ attn_processors_state_dict = {}
701
+
702
+ for attn_processor_key, attn_processor in attn_processors.items():
703
+ for parameter_key, parameter in attn_processor.state_dict().items():
704
+ attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
705
+
706
+ return attn_processors_state_dict
707
+
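+ # The returned dict flattens parameters as "<attn_processor_name>.<param_name>" so the LoRA
+ # weights can be saved and reloaded as a single state dict by the pipeline's LoRA helpers.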
708
+
709
+ def main(args):
710
+ logging_dir = Path(args.output_dir, args.logging_dir)
711
+
712
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
713
+
714
+ accelerator = Accelerator(
715
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
716
+ mixed_precision=args.mixed_precision,
717
+ log_with=args.report_to,
718
+ project_config=accelerator_project_config,
719
+ )
720
+
721
+ if args.report_to == "wandb":
722
+ if not is_wandb_available():
723
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
724
+ import wandb
725
+
726
+ # Make one log on every process with the configuration for debugging.
727
+ logging.basicConfig(
728
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
729
+ datefmt="%m/%d/%Y %H:%M:%S",
730
+ level=logging.INFO,
731
+ )
732
+ logger.info(accelerator.state, main_process_only=False)
733
+ if accelerator.is_local_main_process:
734
+ transformers.utils.logging.set_verbosity_warning()
735
+ diffusers.utils.logging.set_verbosity_info()
736
+ else:
737
+ transformers.utils.logging.set_verbosity_error()
738
+ diffusers.utils.logging.set_verbosity_error()
739
+
740
+ # If passed along, set the training seed now.
741
+ if args.seed is not None:
742
+ set_seed(args.seed)
743
+
744
+ # Generate class images if prior preservation is enabled.
745
+ if args.with_prior_preservation:
746
+ class_images_dir = Path(args.class_data_dir)
747
+ if not class_images_dir.exists():
748
+ class_images_dir.mkdir(parents=True)
749
+ cur_class_images = len(list(class_images_dir.iterdir()))
750
+
751
+ if cur_class_images < args.num_class_images:
752
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
753
+ if args.prior_generation_precision == "fp32":
754
+ torch_dtype = torch.float32
755
+ elif args.prior_generation_precision == "fp16":
756
+ torch_dtype = torch.float16
757
+ elif args.prior_generation_precision == "bf16":
758
+ torch_dtype = torch.bfloat16
759
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
760
+ args.pretrained_model_name_or_path,
761
+ torch_dtype=torch_dtype,
762
+ revision=args.revision,
763
+ )
764
+ pipeline.set_progress_bar_config(disable=True)
765
+
766
+ num_new_images = args.num_class_images - cur_class_images
767
+ logger.info(f"Number of class images to sample: {num_new_images}.")
768
+
769
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
770
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
771
+
772
+ sample_dataloader = accelerator.prepare(sample_dataloader)
773
+ pipeline.to(accelerator.device)
774
+
775
+ for example in tqdm(
776
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
777
+ ):
778
+ images = pipeline(example["prompt"]).images
779
+
780
+ for i, image in enumerate(images):
781
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
782
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
783
+ image.save(image_filename)
784
+
785
+ del pipeline
786
+ if torch.cuda.is_available():
787
+ torch.cuda.empty_cache()
788
+
789
+ # Handle the repository creation
790
+ if accelerator.is_main_process:
791
+ if args.output_dir is not None:
792
+ os.makedirs(args.output_dir, exist_ok=True)
793
+
794
+ if args.push_to_hub:
795
+ repo_id = create_repo(
796
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
797
+ ).repo_id
798
+
799
+ # Load the tokenizers
800
+ tokenizer_one = AutoTokenizer.from_pretrained(
801
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
802
+ )
803
+ tokenizer_two = AutoTokenizer.from_pretrained(
804
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
805
+ )
806
+
807
+ # import correct text encoder classes
808
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
809
+ args.pretrained_model_name_or_path, args.revision
810
+ )
811
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
812
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
813
+ )
814
+
815
+ # Load scheduler and models
816
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
817
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
818
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
819
+ )
820
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
821
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
822
+ )
823
+ vae_path = (
824
+ args.pretrained_model_name_or_path
825
+ if args.pretrained_vae_model_name_or_path is None
826
+ else args.pretrained_vae_model_name_or_path
827
+ )
828
+ vae = AutoencoderKL.from_pretrained(
829
+ vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
830
+ )
831
+ unet = UNet2DConditionModel.from_pretrained(
832
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
833
+ )
834
+
835
+ # We only train the additional adapter LoRA layers
836
+ vae.requires_grad_(False)
837
+ text_encoder_one.requires_grad_(False)
838
+ text_encoder_two.requires_grad_(False)
839
+ unet.requires_grad_(False)
840
+
841
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
842
+ # as these weights are only used for inference, keeping weights in full precision is not required.
843
+ weight_dtype = torch.float32
844
+ if accelerator.mixed_precision == "fp16":
845
+ weight_dtype = torch.float16
846
+ elif accelerator.mixed_precision == "bf16":
847
+ weight_dtype = torch.bfloat16
848
+
849
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
850
+ unet.to(accelerator.device, dtype=weight_dtype)
851
+
852
+ # The VAE is always in float32 to avoid NaN losses.
853
+ vae.to(accelerator.device, dtype=torch.float32)
854
+
855
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
856
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
857
+
858
+ if args.enable_xformers_memory_efficient_attention:
859
+ if is_xformers_available():
860
+ import xformers
861
+
862
+ xformers_version = version.parse(xformers.__version__)
863
+ if xformers_version == version.parse("0.0.16"):
864
+ logger.warn(
865
+ "xFormers 0.0.16 cannot be used for training on some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
866
+ )
867
+ unet.enable_xformers_memory_efficient_attention()
868
+ else:
869
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
870
+
871
+ if args.gradient_checkpointing:
872
+ unet.enable_gradient_checkpointing()
873
+ if args.train_text_encoder:
874
+ text_encoder_one.gradient_checkpointing_enable()
875
+ text_encoder_two.gradient_checkpointing_enable()
876
+
877
+ # now we will add new LoRA weights to the attention layers
878
+ # Set correct lora layers
879
+ unet_lora_attn_procs = {}
880
+ unet_lora_parameters = []
881
+ for name, attn_processor in unet.attn_processors.items():
882
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
883
+ if name.startswith("mid_block"):
884
+ hidden_size = unet.config.block_out_channels[-1]
885
+ elif name.startswith("up_blocks"):
886
+ block_id = int(name[len("up_blocks.")])
887
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
888
+ elif name.startswith("down_blocks"):
889
+ block_id = int(name[len("down_blocks.")])
890
+ hidden_size = unet.config.block_out_channels[block_id]
891
+
892
+ lora_attn_processor_class = (
893
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
894
+ )
895
+ module = lora_attn_processor_class(
896
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
897
+ )
898
+ unet_lora_attn_procs[name] = module
899
+ unet_lora_parameters.extend(module.parameters())
900
+
901
+ unet.set_attn_processor(unet_lora_attn_procs)
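+ # At this point every attention module of the UNet has a LoRA processor: a pair of low-rank
+ # matrices of rank `args.rank` per projection. Only these (and, optionally, the text-encoder
+ # LoRA layers added below) receive gradients; the frozen base weights stay untouched.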
902
+
903
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
904
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
905
+ if args.train_text_encoder:
906
+ # ensure that dtype is float32, even if the rest of the model that isn't trained is loaded in fp16
907
+ text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
908
+ text_encoder_one, dtype=torch.float32, rank=args.rank
909
+ )
910
+ text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
911
+ text_encoder_two, dtype=torch.float32, rank=args.rank
912
+ )
913
+
914
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
915
+ def save_model_hook(models, weights, output_dir):
916
+ if accelerator.is_main_process:
917
+ # There are only two options here: either these are just the unet attention processor layers,
918
+ # or they are the unet and text encoder attention layers.
919
+ unet_lora_layers_to_save = None
920
+ text_encoder_one_lora_layers_to_save = None
921
+ text_encoder_two_lora_layers_to_save = None
922
+
923
+ for model in models:
924
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
925
+ unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
926
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
927
+ text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
928
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
929
+ text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
930
+ else:
931
+ raise ValueError(f"unexpected save model: {model.__class__}")
932
+
933
+ # make sure to pop weight so that corresponding model is not saved again
934
+ weights.pop()
935
+
936
+ StableDiffusionXLPipeline.save_lora_weights(
937
+ output_dir,
938
+ unet_lora_layers=unet_lora_layers_to_save,
939
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
940
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
941
+ )
942
+
943
+ def load_model_hook(models, input_dir):
944
+ unet_ = None
945
+ text_encoder_one_ = None
946
+ text_encoder_two_ = None
947
+
948
+ while len(models) > 0:
949
+ model = models.pop()
950
+
951
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
952
+ unet_ = model
953
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
954
+ text_encoder_one_ = model
955
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
956
+ text_encoder_two_ = model
957
+ else:
958
+ raise ValueError(f"unexpected model when loading: {model.__class__}")
959
+
960
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
961
+ LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
962
+
963
+ text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
964
+ LoraLoaderMixin.load_lora_into_text_encoder(
965
+ text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
966
+ )
967
+
968
+ text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
969
+ LoraLoaderMixin.load_lora_into_text_encoder(
970
+ text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
971
+ )
972
+
973
+ accelerator.register_save_state_pre_hook(save_model_hook)
974
+ accelerator.register_load_state_pre_hook(load_model_hook)
975
+
976
+ # Enable TF32 for faster training on Ampere GPUs,
977
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
978
+ if args.allow_tf32:
979
+ torch.backends.cuda.matmul.allow_tf32 = True
980
+
981
+ if args.scale_lr:
982
+ args.learning_rate = (
983
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
984
+ )
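+ # e.g. with the defaults (lr=5e-4, grad_accum=1, train_batch_size=4) on a single process
+ # this scales the learning rate to 5e-4 * 1 * 4 * 1 = 2e-3.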
985
+
986
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
987
+ if args.use_8bit_adam:
988
+ try:
989
+ import bitsandbytes as bnb
990
+ except ImportError:
991
+ raise ImportError(
992
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
993
+ )
994
+
995
+ optimizer_class = bnb.optim.AdamW8bit
996
+ else:
997
+ optimizer_class = torch.optim.AdamW
998
+
999
+ # Optimizer creation
1000
+ params_to_optimize = (
1001
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
1002
+ if args.train_text_encoder
1003
+ else unet_lora_parameters
1004
+ )
1005
+ optimizer = optimizer_class(
1006
+ params_to_optimize,
1007
+ lr=args.learning_rate,
1008
+ betas=(args.adam_beta1, args.adam_beta2),
1009
+ weight_decay=args.adam_weight_decay,
1010
+ eps=args.adam_epsilon,
1011
+ )
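+ # Note that only the LoRA parameters are passed to the optimizer; the frozen UNet, VAE and
+ # text-encoder weights are never updated.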
1012
+
1013
+ # Computes additional embeddings/ids required by the SDXL UNet.
1014
+ # regular text embeddings (when `train_text_encoder` is not True)
1015
+ # pooled text embeddings
1016
+ # time ids
1017
+
1018
+ def compute_time_ids():
1019
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1020
+ original_size = (args.resolution, args.resolution)
1021
+ target_size = (args.resolution, args.resolution)
1022
+ crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
1023
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1024
+ add_time_ids = torch.tensor([add_time_ids])
1025
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
1026
+ return add_time_ids
1027
+
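+ # For example, with --resolution=1024 and the default crop offsets, compute_time_ids() returns
+ # tensor([[1024., 1024., 0., 0., 1024., 1024.]]) cast to `weight_dtype` on the accelerator device.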
1028
+ if not args.train_text_encoder:
1029
+ tokenizers = [tokenizer_one, tokenizer_two]
1030
+ text_encoders = [text_encoder_one, text_encoder_two]
1031
+
1032
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
1033
+ with torch.no_grad():
1034
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
1035
+ prompt_embeds = prompt_embeds.to(accelerator.device)
1036
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
1037
+ return prompt_embeds, pooled_prompt_embeds
1038
+
1039
+ # Handle instance prompt.
1040
+ instance_time_ids = compute_time_ids()
1041
+ if not args.train_text_encoder:
1042
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
1043
+ args.instance_prompt, text_encoders, tokenizers
1044
+ )
1045
+
1046
+ # Handle class prompt for prior-preservation.
1047
+ if args.with_prior_preservation:
1048
+ class_time_ids = compute_time_ids()
1049
+ if not args.train_text_encoder:
1050
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
1051
+ args.class_prompt, text_encoders, tokenizers
1052
+ )
1053
+
1054
+ # Clear the memory here.
1055
+ if not args.train_text_encoder:
1056
+ del tokenizers, text_encoders
1057
+ gc.collect()
1058
+ torch.cuda.empty_cache()
1059
+
1060
+ # Pack the statically computed variables appropriately. This is so that we don't
1061
+ # have to pass them to the dataloader.
1062
+ add_time_ids = instance_time_ids
1063
+ if args.with_prior_preservation:
1064
+ add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
1065
+
1066
+ if not args.train_text_encoder:
1067
+ prompt_embeds = instance_prompt_hidden_states
1068
+ unet_add_text_embeds = instance_pooled_prompt_embeds
1069
+ if args.with_prior_preservation:
1070
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
1071
+ unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
1072
+ else:
1073
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
1074
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
1075
+ if args.with_prior_preservation:
1076
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
1077
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
1078
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
1079
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
1080
+
1081
+ # Dataset and DataLoaders creation:
1082
+ train_dataset = DreamBoothDataset(
1083
+ instance_data_root=args.instance_data_dir,
1084
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
1085
+ class_num=args.num_class_images,
1086
+ size=args.resolution,
1087
+ center_crop=args.center_crop,
1088
+ )
1089
+
1090
+ train_dataloader = torch.utils.data.DataLoader(
1091
+ train_dataset,
1092
+ batch_size=args.train_batch_size,
1093
+ shuffle=True,
1094
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1095
+ num_workers=args.dataloader_num_workers,
1096
+ )
1097
+
1098
+ # Scheduler and math around the number of training steps.
1099
+ overrode_max_train_steps = False
1100
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1101
+ if args.max_train_steps is None:
1102
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1103
+ overrode_max_train_steps = True
1104
+
1105
+ lr_scheduler = get_scheduler(
1106
+ args.lr_scheduler,
1107
+ optimizer=optimizer,
1108
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1109
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
1110
+ num_cycles=args.lr_num_cycles,
1111
+ power=args.lr_power,
1112
+ )
1113
+
1114
+ # Prepare everything with our `accelerator`.
1115
+ if args.train_text_encoder:
1116
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1117
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
1118
+ )
1119
+ else:
1120
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1121
+ unet, optimizer, train_dataloader, lr_scheduler
1122
+ )
1123
+
1124
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1125
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1126
+ if overrode_max_train_steps:
1127
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1128
+ # Afterwards we recalculate our number of training epochs
1129
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1130
+
1131
+ # We need to initialize the trackers we use, and also store our configuration.
1132
+ # The trackers are initialized automatically on the main process.
1133
+ if accelerator.is_main_process:
1134
+ accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
1135
+
1136
+ # Train!
1137
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1138
+
1139
+ logger.info("***** Running training *****")
1140
+ logger.info(f" Num examples = {len(train_dataset)}")
1141
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1142
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1143
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1144
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1145
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1146
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1147
+ global_step = 0
1148
+ first_epoch = 0
1149
+
1150
+ # Potentially load in the weights and states from a previous save
1151
+ if args.resume_from_checkpoint:
1152
+ if args.resume_from_checkpoint != "latest":
1153
+ path = os.path.basename(args.resume_from_checkpoint)
1154
+ else:
1155
+ # Get the most recent checkpoint
1156
+ dirs = os.listdir(args.output_dir)
1157
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1158
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1159
+ path = dirs[-1] if len(dirs) > 0 else None
1160
+
1161
+ if path is None:
1162
+ accelerator.print(
1163
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1164
+ )
1165
+ args.resume_from_checkpoint = None
1166
+ initial_global_step = 0
1167
+ else:
1168
+ accelerator.print(f"Resuming from checkpoint {path}")
1169
+ accelerator.load_state(os.path.join(args.output_dir, path))
1170
+ global_step = int(path.split("-")[1])
1171
+
1172
+ initial_global_step = global_step
1173
+ first_epoch = global_step // num_update_steps_per_epoch
1174
+
1175
+ else:
1176
+ initial_global_step = 0
1177
+
1178
+ progress_bar = tqdm(
1179
+ range(0, args.max_train_steps),
1180
+ initial=initial_global_step,
1181
+ desc="Steps",
1182
+ # Only show the progress bar once on each machine.
1183
+ disable=not accelerator.is_local_main_process,
1184
+ )
1185
+
1186
+ for epoch in range(first_epoch, args.num_train_epochs):
1187
+ # Print a message for each epoch
1188
+ print(f"Epoch {epoch}: Training in progress...")
1189
+ unet.train()
1190
+ if args.train_text_encoder:
1191
+ text_encoder_one.train()
1192
+ text_encoder_two.train()
1193
+ for step, batch in enumerate(train_dataloader):
1194
+ with accelerator.accumulate(unet):
1195
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1196
+
1197
+ # Convert images to latent space
1198
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1199
+ model_input = model_input * vae.config.scaling_factor
1200
+ if args.pretrained_vae_model_name_or_path is None:
1201
+ model_input = model_input.to(weight_dtype)
1202
+
1203
+ # Sample noise that we'll add to the latents
1204
+ noise = torch.randn_like(model_input)
1205
+ bsz = model_input.shape[0]
1206
+ # Sample a random timestep for each image
1207
+ timesteps = torch.randint(
1208
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1209
+ )
1210
+ timesteps = timesteps.long()
1211
+
1212
+ # Add noise to the model input according to the noise magnitude at each timestep
1213
+ # (this is the forward diffusion process)
1214
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
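+ # i.e. noisy_x = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise (the standard DDPM forward process)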
1215
+
1216
+ # Calculate the elements to repeat depending on the use of prior-preservation.
1217
+ elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz
1218
+
1219
+ # Predict the noise residual
1220
+ if not args.train_text_encoder:
1221
+ unet_added_conditions = {
1222
+ "time_ids": add_time_ids.repeat(elems_to_repeat, 1),
1223
+ "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
1224
+ }
1225
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
1226
+ model_pred = unet(
1227
+ noisy_model_input,
1228
+ timesteps,
1229
+ prompt_embeds_input,
1230
+ added_cond_kwargs=unet_added_conditions,
1231
+ ).sample
1232
+ else:
1233
+ unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)}
1234
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1235
+ text_encoders=[text_encoder_one, text_encoder_two],
1236
+ tokenizers=None,
1237
+ prompt=None,
1238
+ text_input_ids_list=[tokens_one, tokens_two],
1239
+ )
1240
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
1241
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
1242
+ model_pred = unet(
1243
+ noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
1244
+ ).sample
1245
+
1246
+ # Get the target for loss depending on the prediction type
1247
+ if noise_scheduler.config.prediction_type == "epsilon":
1248
+ target = noise
1249
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1250
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1251
+ else:
1252
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1253
+
1254
+ if args.with_prior_preservation:
1255
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1256
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1257
+ target, target_prior = torch.chunk(target, 2, dim=0)
1258
+
1259
+ # Compute instance loss
1260
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1261
+
1262
+ # Compute prior loss
1263
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
1264
+
1265
+ # Add the prior loss to the instance loss.
1266
+ loss = loss + args.prior_loss_weight * prior_loss
1267
+ else:
1268
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1269
+
1270
+ accelerator.backward(loss)
1271
+ if accelerator.sync_gradients:
1272
+ params_to_clip = (
1273
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
1274
+ if args.train_text_encoder
1275
+ else unet_lora_parameters
1276
+ )
1277
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1278
+ optimizer.step()
1279
+ lr_scheduler.step()
1280
+ optimizer.zero_grad()
1281
+
1282
+ # Checks if the accelerator has performed an optimization step behind the scenes
1283
+ if accelerator.sync_gradients:
1284
+ # Print a message for each step
1285
+ print(f"Step {global_step}/{args.max_train_steps}: Done")
1286
+ progress_bar.update(1)
1287
+ global_step += 1
1288
+
1289
+ if accelerator.is_main_process:
1290
+ if global_step % args.checkpointing_steps == 0:
1291
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1292
+ if args.checkpoints_total_limit is not None:
1293
+ checkpoints = os.listdir(args.output_dir)
1294
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1295
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1296
+
1297
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1298
+ if len(checkpoints) >= args.checkpoints_total_limit:
1299
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1300
+ removing_checkpoints = checkpoints[0:num_to_remove]
1301
+
1302
+ logger.info(
1303
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1304
+ )
1305
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1306
+
1307
+ for removing_checkpoint in removing_checkpoints:
1308
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1309
+ shutil.rmtree(removing_checkpoint)
1310
+
1311
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1312
+ accelerator.save_state(save_path)
1313
+ logger.info(f"Saved state to {save_path}")
1314
+
1315
+ save_tempo_model_card(
1316
+ repo_id,
1317
+ dataset_id=args.dataset_id,
1318
+ base_model=args.pretrained_model_name_or_path,
1319
+ train_text_encoder=args.train_text_encoder,
1320
+ prompt=args.instance_prompt,
1321
+ repo_folder=args.output_dir,
1322
+ vae_path=args.pretrained_vae_model_name_or_path,
1323
+ last_checkpoint=f"checkpoint-{global_step}"
1324
+ )
1325
+
1326
+ upload_folder(
1327
+ repo_id=repo_id,
1328
+ folder_path=args.output_dir,
1329
+ commit_message=f"saving checkpoint-{global_step}",
1330
+ ignore_patterns=["step_*", "epoch_*"],
1331
+ token=args.hub_token
1332
+ )
1333
+
1334
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1335
+ progress_bar.set_postfix(**logs)
1336
+ accelerator.log(logs, step=global_step)
1337
+
1338
+ if global_step >= args.max_train_steps:
1339
+ break
1340
+
1341
+ if accelerator.is_main_process:
1342
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1343
+ logger.info(
1344
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1345
+ f" {args.validation_prompt}."
1346
+ )
1347
+ # create pipeline
1348
+ if not args.train_text_encoder:
1349
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
1350
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
1351
+ )
1352
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
1353
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
1354
+ )
1355
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1356
+ args.pretrained_model_name_or_path,
1357
+ vae=vae,
1358
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
1359
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
1360
+ unet=accelerator.unwrap_model(unet),
1361
+ revision=args.revision,
1362
+ torch_dtype=weight_dtype,
1363
+ )
1364
+
1365
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1366
+ scheduler_args = {}
1367
+
1368
+ if "variance_type" in pipeline.scheduler.config:
1369
+ variance_type = pipeline.scheduler.config.variance_type
1370
+
1371
+ if variance_type in ["learned", "learned_range"]:
1372
+ variance_type = "fixed_small"
1373
+
1374
+ scheduler_args["variance_type"] = variance_type
1375
+
1376
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1377
+ pipeline.scheduler.config, **scheduler_args
1378
+ )
1379
+
1380
+ pipeline = pipeline.to(accelerator.device)
1381
+ pipeline.set_progress_bar_config(disable=True)
1382
+
1383
+ # run inference
1384
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1385
+ pipeline_args = {"prompt": args.validation_prompt}
1386
+
1387
+ with torch.cuda.amp.autocast():
1388
+ images = [
1389
+ pipeline(**pipeline_args, generator=generator).images[0]
1390
+ for _ in range(args.num_validation_images)
1391
+ ]
1392
+
1393
+ for tracker in accelerator.trackers:
1394
+ if tracker.name == "tensorboard":
1395
+ np_images = np.stack([np.asarray(img) for img in images])
1396
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1397
+ if tracker.name == "wandb":
1398
+ tracker.log(
1399
+ {
1400
+ "validation": [
1401
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1402
+ for i, image in enumerate(images)
1403
+ ]
1404
+ }
1405
+ )
1406
+
1407
+ del pipeline
1408
+ torch.cuda.empty_cache()
1409
+
1410
+ # Save the lora layers
1411
+ accelerator.wait_for_everyone()
1412
+ if accelerator.is_main_process:
1413
+ unet = accelerator.unwrap_model(unet)
1414
+ unet = unet.to(torch.float32)
1415
+ unet_lora_layers = unet_attn_processors_state_dict(unet)
1416
+
1417
+ if args.train_text_encoder:
1418
+ text_encoder_one = accelerator.unwrap_model(text_encoder_one)
1419
+ text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
1420
+ text_encoder_two = accelerator.unwrap_model(text_encoder_two)
1421
+ text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
1422
+ else:
1423
+ text_encoder_lora_layers = None
1424
+ text_encoder_2_lora_layers = None
1425
+
1426
+ StableDiffusionXLPipeline.save_lora_weights(
1427
+ save_directory=args.output_dir,
1428
+ unet_lora_layers=unet_lora_layers,
1429
+ text_encoder_lora_layers=text_encoder_lora_layers,
1430
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
1431
+ )
1432
+
1433
+ # Final inference
1434
+ # Load previous pipeline
1435
+ vae = AutoencoderKL.from_pretrained(
1436
+ vae_path,
1437
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1438
+ revision=args.revision,
1439
+ torch_dtype=weight_dtype,
1440
+ )
1441
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1442
+ args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
1443
+ )
1444
+
1445
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1446
+ scheduler_args = {}
1447
+
1448
+ if "variance_type" in pipeline.scheduler.config:
1449
+ variance_type = pipeline.scheduler.config.variance_type
1450
+
1451
+ if variance_type in ["learned", "learned_range"]:
1452
+ variance_type = "fixed_small"
1453
+
1454
+ scheduler_args["variance_type"] = variance_type
1455
+
1456
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1457
+
1458
+ # load attention processors
1459
+ pipeline.load_lora_weights(args.output_dir)
1460
+
1461
+ # run inference
1462
+ images = []
1463
+ if args.validation_prompt and args.num_validation_images > 0:
1464
+ pipeline = pipeline.to(accelerator.device)
1465
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1466
+ images = [
1467
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1468
+ for _ in range(args.num_validation_images)
1469
+ ]
1470
+
1471
+ for tracker in accelerator.trackers:
1472
+ if tracker.name == "tensorboard":
1473
+ np_images = np.stack([np.asarray(img) for img in images])
1474
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1475
+ if tracker.name == "wandb":
1476
+ tracker.log(
1477
+ {
1478
+ "test": [
1479
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1480
+ for i, image in enumerate(images)
1481
+ ]
1482
+ }
1483
+ )
1484
+
1485
+ if args.push_to_hub:
1486
+ save_model_card(
1487
+ repo_id,
1488
+ images=images,
1489
+ dataset_id=args.dataset_id,
1490
+ base_model=args.pretrained_model_name_or_path,
1491
+ train_text_encoder=args.train_text_encoder,
1492
+ prompt=args.instance_prompt,
1493
+ repo_folder=args.output_dir,
1494
+ vae_path=args.pretrained_vae_model_name_or_path,
1495
+ )
1496
+ upload_folder(
1497
+ repo_id=repo_id,
1498
+ folder_path=args.output_dir,
1499
+ commit_message="End of training",
1500
+ ignore_patterns=["step_*", "epoch_*"],
1501
+ token=args.hub_token
1502
+ )
1503
+
1504
+ accelerator.end_training()
1505
+
1506
+ if __name__ == "__main__":
1507
+ args = parse_args()
1508
+ main(args)