joachimsallstrom committed
Commit a0da56a · 1 Parent(s): 2a7462a

Upload 5 files

SDXL-LoRA-RNPD.ipynb ADDED
@@ -0,0 +1,317 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "b01332d1-1384-4405-8af6-335c768da6e2",
6
+ "metadata": {},
7
+ "source": [
8
+ "## SDXL LoRA Trainer by TheLastBen https://github.com/TheLastBen/fast-stable-diffusion, if you encounter any issues, feel free to discuss them."
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "8f82bb3b-76de-4e2c-9251-df918f8f2cbe",
14
+ "metadata": {},
15
+ "source": [
16
+ "# Dependencies"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 1,
22
+ "id": "3d144e06-1f7a-467b-9cf1-452bf773f0ab",
23
+ "metadata": {
24
+ "tags": []
25
+ },
26
+ "outputs": [
27
+ {
28
+ "data": {
29
+ "application/vnd.jupyter.widget-view+json": {
30
+ "model_id": "d1e84d74d92c46f8aa78c03f50a0d0d8",
31
+ "version_major": 2,
32
+ "version_minor": 0
33
+ },
34
+ "text/plain": [
35
+ "Button(button_style='success', description='Done!', disabled=True, icon='check', style=ButtonStyle(), tooltip=…"
36
+ ]
37
+ },
38
+ "metadata": {},
39
+ "output_type": "display_data"
40
+ }
41
+ ],
42
+ "source": [
43
+ "# Install the dependencies\n",
44
+ "\n",
45
+ "force_reinstall= False\n",
46
+ "\n",
47
+ "# Set to true only if you want to install the dependencies again.\n",
48
+ "\n",
49
+ "#--------------------\n",
50
+ "with open('/dev/null', 'w') as devnull:import requests, os, time, importlib;open('/workspace/sdxllorarunpod.py', 'wb').write(requests.get('https://huggingface.co/datasets/TheLastBen/RNPD/raw/main/Scripts/sdxllorarunpod.py').content);os.chdir('/workspace');import sdxllorarunpod;importlib.reload(sdxllorarunpod);from sdxllorarunpod import *;restored=False;restoreda=False;Deps(force_reinstall)"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "id": "461b7686-e4aa-4fa8-ab6f-5a6acbf4c601",
56
+ "metadata": {},
57
+ "source": [
58
+ "# Download the model"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 2,
64
+ "id": "2f705bd1-35c9-49bd-84fd-03a1348cbe83",
65
+ "metadata": {
66
+ "tags": []
67
+ },
68
+ "outputs": [
69
+ {
70
+ "name": "stdout",
71
+ "output_type": "stream",
72
+ "text": [
73
+ "\u001b[1;32mUsing SDXL model\n"
74
+ ]
75
+ }
76
+ ],
77
+ "source": [
78
+ "# Run the cell to download the model\n",
79
+ "\n",
80
+ "#-------------\n",
81
+ "MODEL_NAMExl=dls_xlf(\"\", \"\", \"\")"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "markdown",
86
+ "id": "8e22327b-e0c3-424c-82e1-fb7f8a815c0b",
87
+ "metadata": {},
88
+ "source": [
89
+ "# Create/Load a Session"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 3,
95
+ "id": "ac69c221-205a-40d2-b42e-6c8d515a43cc",
96
+ "metadata": {
97
+ "tags": []
98
+ },
99
+ "outputs": [
100
+ {
101
+ "name": "stdout",
102
+ "output_type": "stream",
103
+ "text": [
104
+ "\u001b[1;32mCreating session...\n",
105
+ "\u001b[1;32mSession created, proceed to uploading instance images\n"
106
+ ]
107
+ }
108
+ ],
109
+ "source": [
110
+ "Session_Name = \"aether_skin_230808_SDXL_LoRA_128_dim_50_epochs\"\n",
111
+ "\n",
112
+ "# Enter the session name, it if it exists, it will load it, otherwise it'll create an new session.\n",
113
+ "\n",
114
+ "#-----------------\n",
115
+ "[WORKSPACE, Session_Name, INSTANCE_NAME, OUTPUT_DIR, SESSION_DIR, INSTANCE_DIR, CAPTIONS_DIR, MDLPTH, MODEL_NAMExl]=sess_xl(Session_Name, MODEL_NAMExl if 'MODEL_NAMExl' in locals() else \"\")"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "markdown",
120
+ "id": "5d239e77-f7fd-404b-8006-081f15326412",
121
+ "metadata": {},
122
+ "source": [
123
+ "# Train LoRA"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": null,
129
+ "id": "c54a7335-8402-42f2-9a71-9da99f6ea604",
130
+ "metadata": {
131
+ "tags": []
132
+ },
133
+ "outputs": [
134
+ {
135
+ "name": "stdout",
136
+ "output_type": "stream",
137
+ "text": [
138
+ "\u001b[34m'########:'########:::::'###::::'####:'##::: ##:'####:'##::: ##::'######:::\n",
139
+ "... ##..:: ##.... ##:::'## ##:::. ##:: ###:: ##:. ##:: ###:: ##:'##... ##::\n",
140
+ "::: ##:::: ##:::: ##::'##:. ##::: ##:: ####: ##:: ##:: ####: ##: ##:::..:::\n",
141
+ "::: ##:::: ########::'##:::. ##:: ##:: ## ## ##:: ##:: ## ## ##: ##::'####:\n",
142
+ "::: ##:::: ##.. ##::: #########:: ##:: ##. ####:: ##:: ##. ####: ##::: ##::\n",
143
+ "::: ##:::: ##::. ##:: ##.... ##:: ##:: ##:. ###:: ##:: ##:. ###: ##::: ##::\n",
144
+ "::: ##:::: ##:::. ##: ##:::: ##:'####: ##::. ##:'####: ##::. ##:. ######:::\n",
145
+ ":::..:::::..:::::..::..:::::..::....::..::::..::....::..::::..:::......::::\n",
146
+ "\u001b[0m\n"
147
+ ]
148
+ },
149
+ {
150
+ "name": "stderr",
151
+ "output_type": "stream",
152
+ "text": [
153
+ "Progress: 71%|███████ | 676/950 [06:22<02:23, 1.91it/s, loss=0.245, lr=5.75e-7] "
154
+ ]
155
+ }
156
+ ],
157
+ "source": [
158
+ "Resume_Training= False\n",
159
+ "\n",
160
+ "# If you're not satisfied with the result, Set to True, run again the cell and it will continue training the current model.\n",
161
+ "\n",
162
+ "\n",
163
+ "Training_Epochs= 50\n",
164
+ "\n",
165
+ "# Epoch = Number of steps/images.\n",
166
+ "\n",
167
+ "Learning_Rate= \"3e-6\"\n",
168
+ "\n",
169
+ "# keep it between 1e-6 and 6e-6\n",
170
+ "\n",
171
+ "\n",
172
+ "External_Captions= True\n",
173
+ "\n",
174
+ "# Load the captions from a text file for each instance image.\n",
175
+ "\n",
176
+ "\n",
177
+ "LoRA_Dim = 128\n",
178
+ "\n",
179
+ "# Dimension of the LoRa model, between 64 and 128 is good enough.\n",
180
+ "\n",
181
+ "\n",
182
+ "Resolution= 1024\n",
183
+ "\n",
184
+ "# 1024 is the native resolution.\n",
185
+ "\n",
186
+ "\n",
187
+ "Save_VRAM = False\n",
188
+ "\n",
189
+ "# Use as low as 9.7GB VRAM with Dim = 64, but slightly slower training.\n",
190
+ "\n",
191
+ "#-----------------\n",
192
+ "dbtrainxl(Resume_Training, Training_Epochs, Learning_Rate, LoRA_Dim, False, Resolution, MODEL_NAMExl, SESSION_DIR, INSTANCE_DIR, CAPTIONS_DIR, External_Captions, INSTANCE_NAME, Session_Name, OUTPUT_DIR, 0.03, Save_VRAM)"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "markdown",
197
+ "id": "e2751798-508e-47ad-8e54-95188bdab051",
198
+ "metadata": {
199
+ "jp-MarkdownHeadingCollapsed": true,
200
+ "tags": []
201
+ },
202
+ "source": [
203
+ "# Test the Trained Model"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "markdown",
208
+ "id": "d1bc48d6-1526-44c6-ab7c-cc1538c7f61c",
209
+ "metadata": {},
210
+ "source": [
211
+ "# ComfyUI"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "id": "26272665-16de-4042-a7a4-6b9205ff3309",
218
+ "metadata": {
219
+ "scrolled": true,
220
+ "tags": []
221
+ },
222
+ "outputs": [],
223
+ "source": [
224
+ "Args=\"--listen --port 3000\"\n",
225
+ "\n",
226
+ "\n",
227
+ "Download_SDXL_Model= True\n",
228
+ "\n",
229
+ "\n",
230
+ "Huggingface_token_optional= \"\"\n",
231
+ "\n",
232
+ "# Restore your backed-up Comfy folder by entering your huggingface token, leave it empty to start fresh or continue with the existing sd folder (if any).\n",
233
+ "\n",
234
+ "#--------------------\n",
235
+ "restored=sdcmff(Huggingface_token_optional, MDLPTH, Download_SDXL_Model, restored)\n",
236
+ "!python /workspace/ComfyUI/main.py $Args"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "markdown",
241
+ "id": "410520ca-7352-4fc4-907b-cb53f661074e",
242
+ "metadata": {},
243
+ "source": [
244
+ "# A1111"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": null,
250
+ "id": "351f18d5-f723-4d25-b1ae-1296a22c6d8c",
251
+ "metadata": {
252
+ "tags": []
253
+ },
254
+ "outputs": [],
255
+ "source": [
256
+ "User = \"\"\n",
257
+ "\n",
258
+ "Password= \"\"\n",
259
+ "\n",
260
+ "# Add credentials to your Gradio interface (optional).\n",
261
+ "\n",
262
+ "Download_SDXL_Model= True\n",
263
+ "\n",
264
+ "\n",
265
+ "Huggingface_token_optional= \"\"\n",
266
+ "\n",
267
+ "# Restore your backed-up SD folder by entering your huggingface token, leave it empty to start fresh or continue with the existing sd folder (if any).\n",
268
+ "\n",
269
+ "#-----------------\n",
270
+ "configf, restoreda=test(MDLPTH, User, Password, Huggingface_token_optional, Download_SDXL_Model, restoreda)\n",
271
+ "!python /workspace/sd/stable-diffusion-webui/webui.py $configf"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "markdown",
276
+ "id": "093d64a7-3d4e-4197-8075-4ed11c7f0ae8",
277
+ "metadata": {},
278
+ "source": [
279
+ "# Free up space"
280
+ ]
281
+ },
282
+ {
283
+ "cell_type": "code",
284
+ "execution_count": null,
285
+ "id": "370ba58a-d58d-4a80-9575-8c6e094e2626",
286
+ "metadata": {},
287
+ "outputs": [],
288
+ "source": [
289
+ "# Display a list of sessions from which you can remove any session you don't need anymore\n",
290
+ "\n",
291
+ "#-------------------------\n",
292
+ "clean()"
293
+ ]
294
+ }
295
+ ],
296
+ "metadata": {
297
+ "kernelspec": {
298
+ "display_name": "Python 3 (ipykernel)",
299
+ "language": "python",
300
+ "name": "python3"
301
+ },
302
+ "language_info": {
303
+ "codemirror_mode": {
304
+ "name": "ipython",
305
+ "version": 3
306
+ },
307
+ "file_extension": ".py",
308
+ "mimetype": "text/x-python",
309
+ "name": "python",
310
+ "nbconvert_exporter": "python",
311
+ "pygments_lexer": "ipython3",
312
+ "version": "3.10.12"
313
+ }
314
+ },
315
+ "nbformat": 4,
316
+ "nbformat_minor": 5
317
+ }
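
For readability, here is a hedged sketch of what the dependencies cell's one-liner above does, unrolled into ordinary statements. It uses the same script URL and helper names that appear in the notebook; the actual cell additionally opens /dev/null to suppress output, which is omitted here.

```python
# Readable equivalent of the notebook's dependency one-liner (a sketch, not a drop-in replacement).
import importlib
import os

import requests

url = "https://huggingface.co/datasets/TheLastBen/RNPD/raw/main/Scripts/sdxllorarunpod.py"
with open("/workspace/sdxllorarunpod.py", "wb") as f:
    f.write(requests.get(url).content)   # fetch the helper script

os.chdir("/workspace")
import sdxllorarunpod
importlib.reload(sdxllorarunpod)          # pick up a freshly downloaded copy
from sdxllorarunpod import *              # Deps, dls_xlf, sess_xl, dbtrainxl, sdcmff, test, clean

restored = False
restoreda = False
Deps(force_reinstall=False)               # set True to force a reinstall of the dependencies
```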
lora_sdxl.py ADDED
@@ -0,0 +1,1128 @@
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+
6
+ import math
7
+ import os
8
+ from typing import Dict, List, Optional, Tuple, Type, Union
9
+ from diffusers import AutoencoderKL
10
+ from transformers import CLIPTextModel
11
+ import numpy as np
12
+ import torch
13
+ import re
14
+
15
+
16
+ RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
17
+
18
+
19
+ class LoRAModule(torch.nn.Module):
20
+ """
21
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ lora_name,
27
+ org_module: torch.nn.Module,
28
+ multiplier=1.0,
29
+ lora_dim=4,
30
+ alpha=1,
31
+ dropout=None,
32
+ rank_dropout=None,
33
+ module_dropout=None,
34
+ ):
35
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
36
+ super().__init__()
37
+ self.lora_name = lora_name
38
+
39
+ if org_module.__class__.__name__ == "Conv2d":
40
+ in_dim = org_module.in_channels
41
+ out_dim = org_module.out_channels
42
+ else:
43
+ in_dim = org_module.in_features
44
+ out_dim = org_module.out_features
45
+
46
+ # if limit_rank:
47
+ # self.lora_dim = min(lora_dim, in_dim, out_dim)
48
+ # if self.lora_dim != lora_dim:
49
+ # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
50
+ # else:
51
+ self.lora_dim = lora_dim
52
+
53
+ if org_module.__class__.__name__ == "Conv2d":
54
+ kernel_size = org_module.kernel_size
55
+ stride = org_module.stride
56
+ padding = org_module.padding
57
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
58
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
59
+ else:
60
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
61
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
62
+
63
+ if type(alpha) == torch.Tensor:
64
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
65
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
66
+ self.scale = alpha / self.lora_dim
67
+ self.register_buffer("alpha", torch.tensor(alpha)) # can be treated as a constant
68
+
69
+ # same as microsoft's
70
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
71
+ torch.nn.init.zeros_(self.lora_up.weight)
72
+
73
+ self.multiplier = multiplier
74
+ self.org_module = org_module # removed when apply_to() is called
75
+ self.dropout = dropout
76
+ self.rank_dropout = rank_dropout
77
+ self.module_dropout = module_dropout
78
+
79
+ def apply_to(self):
80
+ self.org_forward = self.org_module.forward
81
+ self.org_module.forward = self.forward
82
+ del self.org_module
83
+
84
+ def forward(self, x):
85
+ org_forwarded = self.org_forward(x)
86
+
87
+ # module dropout
88
+ if self.module_dropout is not None and self.training:
89
+ if torch.rand(1) < self.module_dropout:
90
+ return org_forwarded
91
+
92
+ lx = self.lora_down(x)
93
+
94
+ # normal dropout
95
+ if self.dropout is not None and self.training:
96
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
97
+
98
+ # rank dropout
99
+ if self.rank_dropout is not None and self.training:
100
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
101
+ if len(lx.size()) == 3:
102
+ mask = mask.unsqueeze(1) # for Text Encoder
103
+ elif len(lx.size()) == 4:
104
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
105
+ lx = lx * mask
106
+
107
+ # scaling for rank dropout: treat as if the rank is changed
108
+ # the scale could also be computed from the mask, but rank_dropout is used for its augmentation-like effect
109
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
110
+ else:
111
+ scale = self.scale
112
+
113
+ lx = self.lora_up(lx)
114
+
115
+ return org_forwarded + lx * self.multiplier * scale
116
+
117
+
118
+ class LoRAInfModule(LoRAModule):
119
+ def __init__(
120
+ self,
121
+ lora_name,
122
+ org_module: torch.nn.Module,
123
+ multiplier=1.0,
124
+ lora_dim=4,
125
+ alpha=1,
126
+ **kwargs,
127
+ ):
128
+ # no dropout for inference
129
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
130
+
131
+ self.org_module_ref = [org_module] # keep a reference so it can be accessed later
132
+ self.enabled = True
133
+
134
+ # check regional or not by lora_name
135
+ self.text_encoder = False
136
+ if lora_name.startswith("lora_te_"):
137
+ self.regional = False
138
+ self.use_sub_prompt = True
139
+ self.text_encoder = True
140
+ elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
141
+ self.regional = False
142
+ self.use_sub_prompt = True
143
+ elif "time_emb" in lora_name:
144
+ self.regional = False
145
+ self.use_sub_prompt = False
146
+ else:
147
+ self.regional = True
148
+ self.use_sub_prompt = False
149
+
150
+ self.network: LoRANetwork = None
151
+
152
+ def set_network(self, network):
153
+ self.network = network
154
+
155
+ # freeze the original module and merge the LoRA weights into it
156
+ def merge_to(self, sd, dtype, device):
157
+ # get up/down weight
158
+ up_weight = sd["lora_up.weight"].to(torch.float).to(device)
159
+ down_weight = sd["lora_down.weight"].to(torch.float).to(device)
160
+
161
+ # extract weight from org_module
162
+ org_sd = self.org_module.state_dict()
163
+ weight = org_sd["weight"].to(torch.float)
164
+
165
+ # merge weight
166
+ if len(weight.size()) == 2:
167
+ # linear
168
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
169
+ elif down_weight.size()[2:4] == (1, 1):
170
+ # conv2d 1x1
171
+ weight = (
172
+ weight
173
+ + self.multiplier
174
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
175
+ * self.scale
176
+ )
177
+ else:
178
+ # conv2d 3x3
179
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
180
+ # print(conved.size(), weight.size(), module.stride, module.padding)
181
+ weight = weight + self.multiplier * conved * self.scale
182
+
183
+ # set weight to org_module
184
+ org_sd["weight"] = weight.to(dtype)
185
+ self.org_module.load_state_dict(org_sd)
186
+
187
+ # return this module's weight so the merge can be restored later
188
+ def get_weight(self, multiplier=None):
189
+ if multiplier is None:
190
+ multiplier = self.multiplier
191
+
192
+ # get up/down weight from module
193
+ up_weight = self.lora_up.weight.to(torch.float)
194
+ down_weight = self.lora_down.weight.to(torch.float)
195
+
196
+ # pre-calculated weight
197
+ if len(down_weight.size()) == 2:
198
+ # linear
199
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
200
+ elif down_weight.size()[2:4] == (1, 1):
201
+ # conv2d 1x1
202
+ weight = (
203
+ self.multiplier
204
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
205
+ * self.scale
206
+ )
207
+ else:
208
+ # conv2d 3x3
209
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
210
+ weight = self.multiplier * conved * self.scale
211
+
212
+ return weight
213
+
214
+ def set_region(self, region):
215
+ self.region = region
216
+ self.region_mask = None
217
+
218
+ def default_forward(self, x):
219
+ # print("default_forward", self.lora_name, x.size())
220
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
221
+
222
+ def forward(self, x):
223
+ if not self.enabled:
224
+ return self.org_forward(x)
225
+
226
+ if self.network is None or self.network.sub_prompt_index is None:
227
+ return self.default_forward(x)
228
+ if not self.regional and not self.use_sub_prompt:
229
+ return self.default_forward(x)
230
+
231
+ if self.regional:
232
+ return self.regional_forward(x)
233
+ else:
234
+ return self.sub_prompt_forward(x)
235
+
236
+ def get_mask_for_x(self, x):
237
+ # calculate size from shape of x
238
+ if len(x.size()) == 4:
239
+ h, w = x.size()[2:4]
240
+ area = h * w
241
+ else:
242
+ area = x.size()[1]
243
+
244
+ mask = self.network.mask_dic[area]
245
+ if mask is None:
246
+ raise ValueError(f"mask is None for resolution {area}")
247
+ if len(x.size()) != 4:
248
+ mask = torch.reshape(mask, (1, -1, 1))
249
+ return mask
250
+
251
+ def regional_forward(self, x):
252
+ if "attn2_to_out" in self.lora_name:
253
+ return self.to_out_forward(x)
254
+
255
+ if self.network.mask_dic is None: # sub_prompt_index >= 3
256
+ return self.default_forward(x)
257
+
258
+ # apply mask for LoRA result
259
+ lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
260
+ mask = self.get_mask_for_x(lx)
261
+ # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
262
+ lx = lx * mask
263
+
264
+ x = self.org_forward(x)
265
+ x = x + lx
266
+
267
+ if "attn2_to_q" in self.lora_name and self.network.is_last_network:
268
+ x = self.postp_to_q(x)
269
+
270
+ return x
271
+
272
+ def postp_to_q(self, x):
273
+ # repeat x to num_sub_prompts
274
+ has_real_uncond = x.size()[0] // self.network.batch_size == 3
275
+ qc = self.network.batch_size # uncond
276
+ qc += self.network.batch_size * self.network.num_sub_prompts # cond
277
+ if has_real_uncond:
278
+ qc += self.network.batch_size # real_uncond
279
+
280
+ query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
281
+ query[: self.network.batch_size] = x[: self.network.batch_size]
282
+
283
+ for i in range(self.network.batch_size):
284
+ qi = self.network.batch_size + i * self.network.num_sub_prompts
285
+ query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
286
+
287
+ if has_real_uncond:
288
+ query[-self.network.batch_size :] = x[-self.network.batch_size :]
289
+
290
+ # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
291
+ return query
292
+
293
+ def sub_prompt_forward(self, x):
294
+ if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
295
+ return self.org_forward(x)
296
+
297
+ emb_idx = self.network.sub_prompt_index
298
+ if not self.text_encoder:
299
+ emb_idx += self.network.batch_size
300
+
301
+ # apply sub prompt of X
302
+ lx = x[emb_idx :: self.network.num_sub_prompts]
303
+ lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
304
+
305
+ # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
306
+
307
+ x = self.org_forward(x)
308
+ x[emb_idx :: self.network.num_sub_prompts] += lx
309
+
310
+ return x
311
+
312
+ def to_out_forward(self, x):
313
+ # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
314
+
315
+ if self.network.is_last_network:
316
+ masks = [None] * self.network.num_sub_prompts
317
+ self.network.shared[self.lora_name] = (None, masks)
318
+ else:
319
+ lx, masks = self.network.shared[self.lora_name]
320
+
321
+ # call own LoRA
322
+ x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
323
+ lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
324
+
325
+ if self.network.is_last_network:
326
+ lx = torch.zeros(
327
+ (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
328
+ )
329
+ self.network.shared[self.lora_name] = (lx, masks)
330
+
331
+ # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
332
+ lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
333
+ masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
334
+
335
+ # if not last network, return x and masks
336
+ x = self.org_forward(x)
337
+ if not self.network.is_last_network:
338
+ return x
339
+
340
+ lx, masks = self.network.shared.pop(self.lora_name)
341
+
342
+ # if last network, combine separated x with mask weighted sum
343
+ has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
344
+
345
+ out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
346
+ out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
347
+ if has_real_uncond:
348
+ out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
349
+
350
+ # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
351
+ # for i in range(len(masks)):
352
+ # if masks[i] is None:
353
+ # masks[i] = torch.zeros_like(masks[-1])
354
+
355
+ mask = torch.cat(masks)
356
+ mask_sum = torch.sum(mask, dim=0) + 1e-4
357
+ for i in range(self.network.batch_size):
358
+ # process one image at a time
359
+ lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
360
+ lx1 = lx1 * mask
361
+ lx1 = torch.sum(lx1, dim=0)
362
+
363
+ xi = self.network.batch_size + i * self.network.num_sub_prompts
364
+ x1 = x[xi : xi + self.network.num_sub_prompts]
365
+ x1 = x1 * mask
366
+ x1 = torch.sum(x1, dim=0)
367
+ x1 = x1 / mask_sum
368
+
369
+ x1 = x1 + lx1
370
+ out[self.network.batch_size + i] = x1
371
+
372
+ # print("to_out_forward", x.size(), out.size(), has_real_uncond)
373
+ return out
374
+
375
+
376
+ def parse_block_lr_kwargs(nw_kwargs):
377
+ down_lr_weight = nw_kwargs.get("down_lr_weight", None)
378
+ mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
379
+ up_lr_weight = nw_kwargs.get("up_lr_weight", None)
380
+
381
+ # if none of these are set, block lr is disabled; return None
382
+ if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
383
+ return None, None, None
384
+
385
+ # extract learning rate weight for each block
386
+ if down_lr_weight is not None:
387
+ # if some parameters are not set, use zero
388
+ if "," in down_lr_weight:
389
+ down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
390
+
391
+ if mid_lr_weight is not None:
392
+ mid_lr_weight = float(mid_lr_weight)
393
+
394
+ if up_lr_weight is not None:
395
+ if "," in up_lr_weight:
396
+ up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
397
+
398
+ down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
399
+ down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
400
+ )
401
+
402
+ return down_lr_weight, mid_lr_weight, up_lr_weight
403
+
404
+
405
+ def create_network(
406
+ multiplier: float,
407
+ network_dim: Optional[int],
408
+ network_alpha: Optional[float],
409
+ unet,
410
+ neuron_dropout: Optional[float] = None,
411
+ **kwargs,
412
+ ):
413
+ if network_dim is None:
414
+ network_dim = 4 # default
415
+ if network_alpha is None:
416
+ network_alpha = 1.0
417
+
418
+ # extract dim/alpha for conv2d, and block dim
419
+ conv_dim = kwargs.get("conv_dim", None)
420
+ conv_alpha = kwargs.get("conv_alpha", None)
421
+ if conv_dim is not None:
422
+ conv_dim = int(conv_dim)
423
+ if conv_alpha is None:
424
+ conv_alpha = 1.0
425
+ else:
426
+ conv_alpha = float(conv_alpha)
427
+
428
+ # block dim/alpha/lr
429
+ block_dims = kwargs.get("block_dims", None)
430
+ down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
431
+
432
+ # if any of these is specified, enable per-block dim (rank)
433
+ if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
434
+ block_alphas = kwargs.get("block_alphas", None)
435
+ conv_block_dims = kwargs.get("conv_block_dims", None)
436
+ conv_block_alphas = kwargs.get("conv_block_alphas", None)
437
+
438
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
439
+ block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
440
+ )
441
+
442
+ # remove block dim/alpha without learning rate
443
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
444
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
445
+ )
446
+
447
+ else:
448
+ block_alphas = None
449
+ conv_block_dims = None
450
+ conv_block_alphas = None
451
+
452
+ # rank/module dropout
453
+ rank_dropout = kwargs.get("rank_dropout", None)
454
+ if rank_dropout is not None:
455
+ rank_dropout = float(rank_dropout)
456
+ module_dropout = kwargs.get("module_dropout", None)
457
+ if module_dropout is not None:
458
+ module_dropout = float(module_dropout)
459
+
460
+ # quite a lot of arguments...
461
+ network = LoRANetwork(
462
+ unet,
463
+ multiplier=multiplier,
464
+ lora_dim=network_dim,
465
+ alpha=network_alpha,
466
+ dropout=neuron_dropout,
467
+ rank_dropout=rank_dropout,
468
+ module_dropout=module_dropout,
469
+ conv_lora_dim=conv_dim,
470
+ conv_alpha=conv_alpha,
471
+ block_dims=block_dims,
472
+ block_alphas=block_alphas,
473
+ conv_block_dims=conv_block_dims,
474
+ conv_block_alphas=conv_block_alphas,
475
+ varbose=True,
476
+ )
477
+
478
+ if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
479
+ network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
480
+
481
+ return network
482
+
483
+
484
+ # this function may also be called from external code
485
+ # network_dim and network_alpha already hold default values
486
+ # block_dims and block_alphas are either both None or both set
487
+ # conv_dim and conv_alpha are either both None or both set
488
+ def get_block_dims_and_alphas(
489
+ block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
490
+ ):
491
+ num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
492
+
493
+ def parse_ints(s):
494
+ return [int(i) for i in s.split(",")]
495
+
496
+ def parse_floats(s):
497
+ return [float(i) for i in s.split(",")]
498
+
499
+ # parse block_dims and block_alphas; they always end up populated
500
+ if block_dims is not None:
501
+ block_dims = parse_ints(block_dims)
502
+ assert (
503
+ len(block_dims) == num_total_blocks
504
+ ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
505
+ else:
506
+ print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
507
+ block_dims = [network_dim] * num_total_blocks
508
+
509
+ if block_alphas is not None:
510
+ block_alphas = parse_floats(block_alphas)
511
+ assert (
512
+ len(block_alphas) == num_total_blocks
513
+ ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
514
+ else:
515
+ print(
516
+ f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
517
+ )
518
+ block_alphas = [network_alpha] * num_total_blocks
519
+
520
+ # parse conv_block_dims and conv_block_alphas only if given; otherwise fall back to conv_dim and conv_alpha
521
+ if conv_block_dims is not None:
522
+ conv_block_dims = parse_ints(conv_block_dims)
523
+ assert (
524
+ len(conv_block_dims) == num_total_blocks
525
+ ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
526
+
527
+ if conv_block_alphas is not None:
528
+ conv_block_alphas = parse_floats(conv_block_alphas)
529
+ assert (
530
+ len(conv_block_alphas) == num_total_blocks
531
+ ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
532
+ else:
533
+ if conv_alpha is None:
534
+ conv_alpha = 1.0
535
+ print(
536
+ f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
537
+ )
538
+ conv_block_alphas = [conv_alpha] * num_total_blocks
539
+ else:
540
+ if conv_dim is not None:
541
+ print(
542
+ f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
543
+ )
544
+ conv_block_dims = [conv_dim] * num_total_blocks
545
+ conv_block_alphas = [conv_alpha] * num_total_blocks
546
+ else:
547
+ conv_block_dims = None
548
+ conv_block_alphas = None
549
+
550
+ return block_dims, block_alphas, conv_block_dims, conv_block_alphas
551
+
552
+
553
+ # define per-block multipliers for layer-wise learning rates; may also be called from external code
554
+ def get_block_lr_weight(
555
+ down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
556
+ ) -> Tuple[List[float], List[float], List[float]]:
557
+ # if no parameters are specified, do nothing (same behavior as before)
558
+ if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
559
+ return None, None, None
560
+
561
+ max_len = LoRANetwork.NUM_OF_BLOCKS # number of up/down blocks in the full model
562
+
563
+ def get_list(name_with_suffix) -> List[float]:
564
+ import math
565
+
566
+ tokens = name_with_suffix.split("+")
567
+ name = tokens[0]
568
+ base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
569
+
570
+ if name == "cosine":
571
+ return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
572
+ elif name == "sine":
573
+ return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
574
+ elif name == "linear":
575
+ return [i / (max_len - 1) + base_lr for i in range(max_len)]
576
+ elif name == "reverse_linear":
577
+ return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
578
+ elif name == "zeros":
579
+ return [0.0 + base_lr] * max_len
580
+ else:
581
+ print(
582
+ "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
583
+ % (name)
584
+ )
585
+ return None
586
+
587
+ if type(down_lr_weight) == str:
588
+ down_lr_weight = get_list(down_lr_weight)
589
+ if type(up_lr_weight) == str:
590
+ up_lr_weight = get_list(up_lr_weight)
591
+
592
+ if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
593
+ print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
594
+ print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
595
+ up_lr_weight = up_lr_weight[:max_len]
596
+ down_lr_weight = down_lr_weight[:max_len]
597
+
598
+ if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
599
+ print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
600
+ print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
601
+
602
+ if down_lr_weight != None and len(down_lr_weight) < max_len:
603
+ down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
604
+ if up_lr_weight != None and len(up_lr_weight) < max_len:
605
+ up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
606
+
607
+ if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
608
+ print("apply block learning rate / 階層別学習率を適用します。")
609
+ if down_lr_weight != None:
610
+ down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
611
+ print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
612
+ else:
613
+ print("down_lr_weight: all 1.0, すべて1.0")
614
+
615
+ if mid_lr_weight != None:
616
+ mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
617
+ print("mid_lr_weight:", mid_lr_weight)
618
+ else:
619
+ print("mid_lr_weight: 1.0")
620
+
621
+ if up_lr_weight != None:
622
+ up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
623
+ print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
624
+ else:
625
+ print("up_lr_weight: all 1.0, すべて1.0")
626
+
627
+ return down_lr_weight, mid_lr_weight, up_lr_weight
628
+
629
+
630
+ # exclude blocks whose lr_weight is 0 from block_dims; may also be called from external code
631
+ def remove_block_dims_and_alphas(
632
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
633
+ ):
634
+ # set 0 to block dim without learning rate to remove the block
635
+ if down_lr_weight != None:
636
+ for i, lr in enumerate(down_lr_weight):
637
+ if lr == 0:
638
+ block_dims[i] = 0
639
+ if conv_block_dims is not None:
640
+ conv_block_dims[i] = 0
641
+ if mid_lr_weight != None:
642
+ if mid_lr_weight == 0:
643
+ block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
644
+ if conv_block_dims is not None:
645
+ conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
646
+ if up_lr_weight != None:
647
+ for i, lr in enumerate(up_lr_weight):
648
+ if lr == 0:
649
+ block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
650
+ if conv_block_dims is not None:
651
+ conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
652
+
653
+ return block_dims, block_alphas, conv_block_dims, conv_block_alphas
654
+
655
+
656
+ # may also be called from external code
657
+ def get_block_index(lora_name: str) -> int:
658
+ block_idx = -1 # invalid lora name
659
+
660
+ m = RE_UPDOWN.search(lora_name)
661
+ if m:
662
+ g = m.groups()
663
+ i = int(g[1])
664
+ j = int(g[3])
665
+ if g[2] == "resnets":
666
+ idx = 3 * i + j
667
+ elif g[2] == "attentions":
668
+ idx = 3 * i + j
669
+ elif g[2] == "upsamplers" or g[2] == "downsamplers":
670
+ idx = 3 * i + 2
671
+
672
+ if g[0] == "down":
673
+ block_idx = 1 + idx # no LoRA corresponds to index 0
674
+ elif g[0] == "up":
675
+ block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
676
+
677
+ elif "mid_block_" in lora_name:
678
+ block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
679
+
680
+ return block_idx
681
+
682
+
683
+ # Create network from weights for inference, weights are not loaded here (because can be merged)
684
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
685
+ if weights_sd is None:
686
+ if os.path.splitext(file)[1] == ".safetensors":
687
+ from safetensors.torch import load_file, safe_open
688
+
689
+ weights_sd = load_file(file)
690
+ else:
691
+ weights_sd = torch.load(file, map_location="cpu")
692
+
693
+ # get dim/alpha mapping
694
+ modules_dim = {}
695
+ modules_alpha = {}
696
+ for key, value in weights_sd.items():
697
+ if "." not in key:
698
+ continue
699
+
700
+ lora_name = key.split(".")[0]
701
+ if "alpha" in key:
702
+ modules_alpha[lora_name] = value
703
+ elif "lora_down" in key:
704
+ dim = value.size()[0]
705
+ modules_dim[lora_name] = dim
706
+ # print(lora_name, value.size(), dim)
707
+
708
+ # support old LoRA without alpha
709
+ for key in modules_dim.keys():
710
+ if key not in modules_alpha:
711
+ modules_alpha[key] = modules_dim[key]
712
+
713
+ module_class = LoRAInfModule if for_inference else LoRAModule
714
+
715
+ network = LoRANetwork(
716
+ text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
717
+ )
718
+
719
+ # block lr
720
+ down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
721
+ if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
722
+ network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
723
+
724
+ return network, weights_sd
725
+
726
+
727
+ class LoRANetwork(torch.nn.Module):
728
+ NUM_OF_BLOCKS = 12 # number of up/down blocks in the full model
729
+
730
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
731
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
732
+ LORA_PREFIX_UNET = "lora_unet"
733
+
734
+ def __init__(
735
+ self,
736
+ unet,
737
+ multiplier: float = 1.0,
738
+ lora_dim: int = 4,
739
+ alpha: float = 1,
740
+ dropout: Optional[float] = None,
741
+ rank_dropout: Optional[float] = None,
742
+ module_dropout: Optional[float] = None,
743
+ conv_lora_dim: Optional[int] = None,
744
+ conv_alpha: Optional[float] = None,
745
+ block_dims: Optional[List[int]] = None,
746
+ block_alphas: Optional[List[float]] = None,
747
+ conv_block_dims: Optional[List[int]] = None,
748
+ conv_block_alphas: Optional[List[float]] = None,
749
+ modules_dim: Optional[Dict[str, int]] = None,
750
+ modules_alpha: Optional[Dict[str, int]] = None,
751
+ module_class: Type[object] = LoRAModule,
752
+ varbose: Optional[bool] = False,
753
+ ) -> None:
754
+ """
755
+ LoRA network: there are many arguments, but the usage patterns are as follows
756
+ 1. specify lora_dim and alpha
757
+ 2. specify lora_dim, alpha, conv_lora_dim and conv_alpha
758
+ 3. specify block_dims and block_alphas : not applied to Conv2d 3x3 layers
759
+ 4. specify block_dims, block_alphas, conv_block_dims and conv_block_alphas : also applied to Conv2d 3x3 layers
760
+ 5. specify modules_dim and modules_alpha (for inference)
761
+ """
762
+ super().__init__()
763
+ self.multiplier = multiplier
764
+
765
+ self.lora_dim = lora_dim
766
+ self.alpha = alpha
767
+ self.conv_lora_dim = conv_lora_dim
768
+ self.conv_alpha = conv_alpha
769
+ self.dropout = dropout
770
+ self.rank_dropout = rank_dropout
771
+ self.module_dropout = module_dropout
772
+
773
+
774
+ # create module instances
775
+ def create_modules(
776
+ is_unet: bool,
777
+ root_module: torch.nn.Module,
778
+ target_replace_modules: List[torch.nn.Module],
779
+ ) -> List[LoRAModule]:
780
+ prefix = (
781
+ self.LORA_PREFIX_UNET
782
+ )
783
+ loras = []
784
+ skipped = []
785
+ for name, module in root_module.named_modules():
786
+ if module.__class__.__name__ in target_replace_modules:
787
+ for child_name, child_module in module.named_modules():
788
+ is_linear = child_module.__class__.__name__ == "Linear"
789
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
790
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
791
+
792
+ if is_linear or is_conv2d:
793
+ lora_name = prefix + "." + name + "." + child_name
794
+ lora_name = lora_name.replace(".", "_")
795
+
796
+ dim = None
797
+ alpha = None
798
+
799
+ if modules_dim is not None:
800
+ # per-module settings provided
801
+ if lora_name in modules_dim:
802
+ dim = modules_dim[lora_name]
803
+ alpha = modules_alpha[lora_name]
804
+ elif is_unet and block_dims is not None:
805
+ # U-Net with block_dims specified
806
+ block_idx = get_block_index(lora_name)
807
+ if is_linear or is_conv2d_1x1:
808
+ dim = block_dims[block_idx]
809
+ alpha = block_alphas[block_idx]
810
+ elif conv_block_dims is not None:
811
+ dim = conv_block_dims[block_idx]
812
+ alpha = conv_block_alphas[block_idx]
813
+ else:
814
+ # default case: target all modules
815
+ if is_linear or is_conv2d_1x1:
816
+ dim = self.lora_dim
817
+ alpha = self.alpha
818
+ elif self.conv_lora_dim is not None:
819
+ dim = self.conv_lora_dim
820
+ alpha = self.conv_alpha
821
+
822
+ if dim is None or dim == 0:
823
+ # record the skipped module
824
+ if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
825
+ skipped.append(lora_name)
826
+ continue
827
+
828
+ lora = module_class(
829
+ lora_name,
830
+ child_module,
831
+ self.multiplier,
832
+ dim,
833
+ alpha,
834
+ dropout=dropout,
835
+ rank_dropout=rank_dropout,
836
+ module_dropout=module_dropout,
837
+ )
838
+ loras.append(lora)
839
+ return loras, skipped
840
+
841
+ # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
842
+ target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
843
+ if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
844
+ target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
845
+
846
+ self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
847
+
848
+
849
+ skipped = skipped_un
850
+
851
+ self.up_lr_weight: List[float] = None
852
+ self.down_lr_weight: List[float] = None
853
+ self.mid_lr_weight: float = None
854
+ self.block_lr = False
855
+
856
+ # assertion
857
+ names = set()
858
+ for lora in self.unet_loras:
859
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
860
+ names.add(lora.lora_name)
861
+
862
+ def set_multiplier(self, multiplier):
863
+ self.multiplier = multiplier
864
+ for lora in self.unet_loras:
865
+ lora.multiplier = self.multiplier
866
+
867
+ def load_weights(self, file):
868
+ if os.path.splitext(file)[1] == ".safetensors":
869
+ from safetensors.torch import load_file
870
+
871
+ weights_sd = load_file(file)
872
+ else:
873
+ weights_sd = torch.load(file, map_location="cpu")
874
+
875
+ info = self.load_state_dict(weights_sd, False)
876
+ return info
877
+
878
+ def apply_to(self, unet, apply_unet=True):
879
+ for lora in self.unet_loras:
880
+ lora.apply_to()
881
+ self.add_module(lora.lora_name, lora)
882
+
883
+ # return whether merging is possible
884
+ def is_mergeable(self):
885
+ return True
886
+
887
+ # TODO refactor to common function with apply_to
888
+ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
889
+ apply_text_encoder = apply_unet = False
890
+ for key in weights_sd.keys():
891
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
892
+ apply_text_encoder = True
893
+ elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
894
+ apply_unet = True
895
+
896
+
897
+ for lora in self.unet_loras:
898
+ sd_for_lora = {}
899
+ for key in weights_sd.keys():
900
+ if key.startswith(lora.lora_name):
901
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
902
+ lora.merge_to(sd_for_lora, dtype, device)
903
+
904
+
905
+ def set_block_lr_weight(
906
+ self,
907
+ up_lr_weight: List[float] = None,
908
+ mid_lr_weight: float = None,
909
+ down_lr_weight: List[float] = None,
910
+ ):
911
+ self.block_lr = True
912
+ self.down_lr_weight = down_lr_weight
913
+ self.mid_lr_weight = mid_lr_weight
914
+ self.up_lr_weight = up_lr_weight
915
+
916
+ def get_lr_weight(self, lora: LoRAModule) -> float:
917
+ lr_weight = 1.0
918
+ block_idx = get_block_index(lora.lora_name)
919
+ if block_idx < 0:
920
+ return lr_weight
921
+
922
+ if block_idx < LoRANetwork.NUM_OF_BLOCKS:
923
+ if self.down_lr_weight != None:
924
+ lr_weight = self.down_lr_weight[block_idx]
925
+ elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
926
+ if self.mid_lr_weight != None:
927
+ lr_weight = self.mid_lr_weight
928
+ elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
929
+ if self.up_lr_weight != None:
930
+ lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
931
+
932
+ return lr_weight
933
+
934
+ def prepare_optimizer_params(self, unet_lr):
935
+ self.requires_grad_(True)
936
+ all_params = []
937
+
938
+ def enumerate_params(loras):
939
+ params = []
940
+ for lora in loras:
941
+ params.extend(lora.parameters())
942
+ return params
943
+
944
+
945
+ if self.unet_loras:
946
+ if self.block_lr:
947
+ # group LoRAs by block so the learning rate can be set and tracked per block
948
+ block_idx_to_lora = {}
949
+ for lora in self.unet_loras:
950
+ idx = get_block_index(lora.lora_name)
951
+ if idx not in block_idx_to_lora:
952
+ block_idx_to_lora[idx] = []
953
+ block_idx_to_lora[idx].append(lora)
954
+
955
+ # set the parameters per block
956
+ for idx, block_loras in block_idx_to_lora.items():
957
+ param_data = {"params": enumerate_params(block_loras)}
958
+
959
+ if unet_lr is not None:
960
+ param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
961
+ if ("lr" in param_data) and (param_data["lr"] == 0):
962
+ continue
963
+ all_params.append(param_data)
964
+
965
+ else:
966
+ param_data = {"params": enumerate_params(self.unet_loras)}
967
+ if unet_lr is not None:
968
+ param_data["lr"] = unet_lr
969
+ all_params.append(param_data)
970
+
971
+ return all_params
972
+
973
+ def enable_gradient_checkpointing(self):
974
+ # not supported
975
+ pass
976
+
977
+ def prepare_grad_etc(self, unet):
978
+ self.requires_grad_(True)
979
+
980
+ def on_epoch_start(self, unet):
981
+ self.train()
982
+
983
+ def get_trainable_params(self):
984
+ return self.parameters()
985
+
986
+ def save_weights(self, file, dtype, metadata):
987
+ if metadata is not None and len(metadata) == 0:
988
+ metadata = None
989
+
990
+ state_dict = self.state_dict()
991
+
992
+ if dtype is not None:
993
+ for key in list(state_dict.keys()):
994
+ v = state_dict[key]
995
+ v = v.detach().clone().to("cpu").to(dtype)
996
+ state_dict[key] = v
997
+
998
+ if os.path.splitext(file)[1] == ".safetensors":
999
+ from safetensors.torch import save_file
1000
+
1001
+ if metadata is None:
1002
+ metadata = {}
1003
+ save_file(state_dict, file, metadata)
1004
+ else:
1005
+ torch.save(state_dict, file)
1006
+
1007
+ # mask is a tensor with values from 0 to 1
1008
+ def set_region(self, sub_prompt_index, is_last_network, mask):
1009
+ if mask.max() == 0:
1010
+ mask = torch.ones_like(mask)
1011
+
1012
+ self.mask = mask
1013
+ self.sub_prompt_index = sub_prompt_index
1014
+ self.is_last_network = is_last_network
1015
+
1016
+ for lora in self.unet_loras:
1017
+ lora.set_network(self)
1018
+
1019
+ def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
1020
+ self.batch_size = batch_size
1021
+ self.num_sub_prompts = num_sub_prompts
1022
+ self.current_size = (height, width)
1023
+ self.shared = shared
1024
+
1025
+ # create masks
1026
+ mask = self.mask
1027
+ mask_dic = {}
1028
+ mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
1029
+ ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
1030
+ dtype = ref_weight.dtype
1031
+ device = ref_weight.device
1032
+
1033
+ def resize_add(mh, mw):
1034
+ # print(mh, mw, mh * mw)
1035
+ m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
1036
+ m = m.to(device, dtype=dtype)
1037
+ mask_dic[mh * mw] = m
1038
+
1039
+ h = height // 8
1040
+ w = width // 8
1041
+ for _ in range(4):
1042
+ resize_add(h, w)
1043
+ if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
1044
+ resize_add(h + h % 2, w + w % 2)
1045
+ h = (h + 1) // 2
1046
+ w = (w + 1) // 2
1047
+
1048
+ self.mask_dic = mask_dic
1049
+
1050
+ def backup_weights(self):
1051
+ # back up the original weights
1052
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1053
+ for lora in loras:
1054
+ org_module = lora.org_module_ref[0]
1055
+ if not hasattr(org_module, "_lora_org_weight"):
1056
+ sd = org_module.state_dict()
1057
+ org_module._lora_org_weight = sd["weight"].detach().clone()
1058
+ org_module._lora_restored = True
1059
+
1060
+ def restore_weights(self):
1061
+ # restore the original weights
1062
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1063
+ for lora in loras:
1064
+ org_module = lora.org_module_ref[0]
1065
+ if not org_module._lora_restored:
1066
+ sd = org_module.state_dict()
1067
+ sd["weight"] = org_module._lora_org_weight
1068
+ org_module.load_state_dict(sd)
1069
+ org_module._lora_restored = True
1070
+
1071
+ def pre_calculation(self):
1072
+ # pre-compute the merged weights
1073
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1074
+ for lora in loras:
1075
+ org_module = lora.org_module_ref[0]
1076
+ sd = org_module.state_dict()
1077
+
1078
+ org_weight = sd["weight"]
1079
+ lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
1080
+ sd["weight"] = org_weight + lora_weight
1081
+ assert sd["weight"].shape == org_weight.shape
1082
+ org_module.load_state_dict(sd)
1083
+
1084
+ org_module._lora_restored = False
1085
+ lora.enabled = False
1086
+
1087
+ def apply_max_norm_regularization(self, max_norm_value, device):
1088
+ downkeys = []
1089
+ upkeys = []
1090
+ alphakeys = []
1091
+ norms = []
1092
+ keys_scaled = 0
1093
+
1094
+ state_dict = self.state_dict()
1095
+ for key in state_dict.keys():
1096
+ if "lora_down" in key and "weight" in key:
1097
+ downkeys.append(key)
1098
+ upkeys.append(key.replace("lora_down", "lora_up"))
1099
+ alphakeys.append(key.replace("lora_down.weight", "alpha"))
1100
+
1101
+ for i in range(len(downkeys)):
1102
+ down = state_dict[downkeys[i]].to(device)
1103
+ up = state_dict[upkeys[i]].to(device)
1104
+ alpha = state_dict[alphakeys[i]].to(device)
1105
+ dim = down.shape[0]
1106
+ scale = alpha / dim
1107
+
1108
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
1109
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
1110
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
1111
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
1112
+ else:
1113
+ updown = up @ down
1114
+
1115
+ updown *= scale
1116
+
1117
+ norm = updown.norm().clamp(min=max_norm_value / 2)
1118
+ desired = torch.clamp(norm, max=max_norm_value)
1119
+ ratio = desired.cpu() / norm.cpu()
1120
+ sqrt_ratio = ratio**0.5
1121
+ if ratio != 1:
1122
+ keys_scaled += 1
1123
+ state_dict[upkeys[i]] *= sqrt_ratio
1124
+ state_dict[downkeys[i]] *= sqrt_ratio
1125
+ scalednorm = updown.norm() * ratio
1126
+ norms.append(scalednorm.item())
1127
+
1128
+ return keys_scaled, sum(norms) / len(norms), max(norms)
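
To make the core idea in LoRAModule concrete, here is a minimal, self-contained sketch of the low-rank update it applies and of the weight merge performed by merge_to. It uses a plain torch.nn.Linear for illustration; the names base, down, up and the random initialization of up are hypothetical (the real module zero-initializes lora_up so training starts as a no-op, whereas here both factors are random so the merge check is non-trivial).

```python
import math

import torch

# Frozen original layer and a rank-r LoRA pair, mirroring LoRAModule for the Linear case.
torch.manual_seed(0)
base = torch.nn.Linear(16, 16)                      # stands in for org_module
r, alpha, multiplier = 4, 1.0, 1.0
down = torch.nn.Linear(16, r, bias=False)           # lora_down
up = torch.nn.Linear(r, 16, bias=False)             # lora_up (zero-initialized in the real module)
torch.nn.init.kaiming_uniform_(down.weight, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(up.weight, a=math.sqrt(5))
scale = alpha / r                                    # self.scale

x = torch.randn(2, 16)
# LoRAModule.forward without dropout: original output plus the scaled low-rank path.
y = base(x) + multiplier * scale * up(down(x))

# merge_to folds the same update into the base weight: W' = W + multiplier * (up @ down) * scale
merged_w = base.weight + multiplier * (up.weight @ down.weight) * scale
y_merged = x @ merged_w.T + base.bias
assert torch.allclose(y, y_merged, atol=1e-5)
```

The same algebra is what lets merge_to bake a trained LoRA directly into the base checkpoint and lets get_weight return the pre-computed delta for restorable merging.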
mainrunpodA1111.py ADDED
@@ -0,0 +1,501 @@
1
+ import os
2
+ from IPython.display import clear_output
3
+ from subprocess import call, getoutput, Popen, run
4
+ import time
5
+ import ipywidgets as widgets
6
+ import requests
7
+ import sys
8
+ import fileinput
9
+ from torch.hub import download_url_to_file
10
+ from urllib.parse import urlparse, parse_qs, unquote
11
+ import re
12
+ import six
13
+
14
+ from urllib.request import urlopen, Request
15
+ import tempfile
16
+ from tqdm import tqdm
17
+
18
+
19
+
20
+
21
+ def Deps(force_reinstall):
22
+
23
+ if not force_reinstall and os.path.exists('/usr/local/lib/python3.10/dist-packages/safetensors'):
24
+ ntbks()
25
+ print('Modules and notebooks updated, dependencies already installed')
26
+ os.environ['TORCH_HOME'] = '/workspace/cache/torch'
27
+ os.environ['PYTHONWARNINGS'] = 'ignore'
28
+ else:
29
+ call('pip install --root-user-action=ignore --disable-pip-version-check --no-deps -qq gdown PyWavelets numpy==1.23.5 accelerate==0.12.0 --force-reinstall', shell=True, stdout=open('/dev/null', 'w'))
30
+ ntbks()
31
+ if os.path.exists('deps'):
32
+ call("rm -r deps", shell=True)
33
+ if os.path.exists('diffusers'):
34
+ call("rm -r diffusers", shell=True)
35
+ call('mkdir deps', shell=True)
36
+ if not os.path.exists('cache'):
37
+ call('mkdir cache', shell=True)
38
+ os.chdir('deps')
39
+ dwn("https://huggingface.co/TheLastBen/dependencies/resolve/main/rnpddeps-t2.tar.zst", "/workspace/deps/rnpddeps-t2.tar.zst", "Installing dependencies")
40
+ call('tar -C / --zstd -xf rnpddeps-t2.tar.zst', shell=True, stdout=open('/dev/null', 'w'))
41
+ call("sed -i 's@~/.cache@/workspace/cache@' /usr/local/lib/python3.10/dist-packages/transformers/utils/hub.py", shell=True)
42
+ os.chdir('/workspace')
43
+ call("git clone --depth 1 -q --branch main https://github.com/TheLastBen/diffusers", shell=True, stdout=open('/dev/null', 'w'))
44
+ #call('pip install --root-user-action=ignore --disable-pip-version-check -qq gradio==3.41.0', shell=True, stdout=open('/dev/null', 'w'))
45
+ call("rm -r deps", shell=True)
46
+ os.chdir('/workspace')
47
+ os.environ['TORCH_HOME'] = '/workspace/cache/torch'
48
+ os.environ['PYTHONWARNINGS'] = 'ignore'
49
+ call("sed -i 's@text = _formatwarnmsg(msg)@text =\"\"@g' /usr/lib/python3.10/warnings.py", shell=True)
50
+ clear_output()
51
+
52
+ done()
53
+
54
+
55
+ def dwn(url, dst, msg):
56
+ file_size = None
57
+ req = Request(url, headers={"User-Agent": "torch.hub"})
58
+ u = urlopen(req)
59
+ meta = u.info()
60
+ if hasattr(meta, 'getheaders'):
61
+ content_length = meta.getheaders("Content-Length")
62
+ else:
63
+ content_length = meta.get_all("Content-Length")
64
+ if content_length is not None and len(content_length) > 0:
65
+ file_size = int(content_length[0])
66
+
67
+ with tqdm(total=file_size, disable=False, mininterval=0.5,
68
+ bar_format=msg+' |{bar:20}| {percentage:3.0f}%') as pbar:
69
+ with open(dst, "wb") as f:
70
+ while True:
71
+ buffer = u.read(8192)
72
+ if len(buffer) == 0:
73
+ break
74
+ f.write(buffer)
75
+ pbar.update(len(buffer))
76
+ f.close()
77
+
78
+
79
+ def ntbks():
80
+
81
+ os.chdir('/workspace')
82
+ if not os.path.exists('Latest_Notebooks'):
83
+ call('mkdir Latest_Notebooks', shell=True)
84
+ else:
85
+ call('rm -r Latest_Notebooks', shell=True)
86
+ call('mkdir Latest_Notebooks', shell=True)
87
+ os.chdir('/workspace/Latest_Notebooks')
88
+ call('wget -q -i https://huggingface.co/datasets/TheLastBen/RNPD/raw/main/Notebooks.txt', shell=True)
89
+ call('rm Notebooks.txt', shell=True)
90
+ os.chdir('/workspace')
91
+
92
+
93
+ def repo(Huggingface_token_optional):
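+ # Restore a previous /workspace/sd backup from the user's HuggingFace dataset when a token is provided, otherwise set up a fresh AUTOMATIC1111 webui clone, then update it.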
94
+
95
+ from slugify import slugify
96
+ from huggingface_hub import HfApi, CommitOperationAdd, create_repo
97
+
98
+ os.chdir('/workspace')
99
+ if Huggingface_token_optional!="":
100
+ username = HfApi().whoami(Huggingface_token_optional)["name"]
101
+ backup=f"https://huggingface.co/datasets/{username}/fast-stable-diffusion/resolve/main/sd_backup_rnpd.tar.zst"
102
+ headers = {"Authorization": f"Bearer {Huggingface_token_optional}"}
103
+ response = requests.head(backup, headers=headers)
104
+ if response.status_code == 302:
105
+ print('Restoring the SD folder...')
106
+ open('/workspace/sd_backup_rnpd.tar.zst', 'wb').write(requests.get(backup, headers=headers).content)
107
+ call('tar --zstd -xf sd_backup_rnpd.tar.zst', shell=True)
108
+ call('rm sd_backup_rnpd.tar.zst', shell=True)
109
+ else:
110
+ print('Backup not found, using a fresh/existing repo...')
111
+ time.sleep(2)
112
+ if not os.path.exists('/workspace/sd/stablediffusiond'): #reset later
113
+ call('wget -q -O sd_mrep.tar.zst https://huggingface.co/TheLastBen/dependencies/resolve/main/sd_mrep.tar.zst', shell=True)
114
+ call('tar --zstd -xf sd_mrep.tar.zst', shell=True)
115
+ call('rm sd_mrep.tar.zst', shell=True)
116
+ os.chdir('/workspace/sd')
117
+ if not os.path.exists('stable-diffusion-webui'):
118
+ call('git clone -q --depth 1 --branch master https://github.com/AUTOMATIC1111/stable-diffusion-webui', shell=True)
119
+
120
+ else:
121
+ print('Installing/Updating the repo...')
122
+ os.chdir('/workspace')
123
+ if not os.path.exists('/workspace/sd/stablediffusiond'): #reset later
124
+ call('wget -q -O sd_mrep.tar.zst https://huggingface.co/TheLastBen/dependencies/resolve/main/sd_mrep.tar.zst', shell=True)
125
+ call('tar --zstd -xf sd_mrep.tar.zst', shell=True)
126
+ call('rm sd_mrep.tar.zst', shell=True)
127
+
128
+ os.chdir('/workspace/sd')
129
+ if not os.path.exists('stable-diffusion-webui'):
130
+ call('git clone -q --depth 1 --branch master https://github.com/AUTOMATIC1111/stable-diffusion-webui', shell=True)
131
+
132
+
133
+ os.chdir('/workspace/sd/stable-diffusion-webui/')
134
+ call('git reset --hard', shell=True)
135
+ print('')
136
+ call('git pull', shell=True)
137
+ os.chdir('/workspace')
138
+ clear_output()
139
+ done()
140
+
141
+
142
+
143
+ def mdl(Original_Model_Version, Path_to_MODEL, MODEL_LINK):
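+ # Pick the checkpoint to serve: a local path, a downloaded link (civitai / gdrive / direct), or one of the bundled originals; returns the path handed to the webui launcher.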
144
+
145
+ import gdown
146
+
147
+ src=getsrc(MODEL_LINK)
148
+
149
+ if not os.path.exists('/workspace/sd/stable-diffusion-webui/models/Stable-diffusion/SDv1-5.ckpt'):
150
+ call('ln -s /workspace/auto-models/* /workspace/sd/stable-diffusion-webui/models/Stable-diffusion', shell=True)
151
+
152
+ if Path_to_MODEL !='':
153
+ if os.path.exists(str(Path_to_MODEL)):
154
+ print('Using the custom model')
155
+ model=Path_to_MODEL
156
+ else:
157
+ print('Wrong path, check that the path to the model is correct')
158
+
159
+ elif MODEL_LINK !="":
160
+
161
+ if src=='civitai':
162
+ modelname=get_name(MODEL_LINK, False)
163
+ model=f'/workspace/sd/stable-diffusion-webui/models/Stable-diffusion/{modelname}'
164
+ if not os.path.exists(model):
165
+ dwn(MODEL_LINK, model, 'Downloading the custom model')
166
+ clear_output()
167
+ else:
168
+ print('Model already exists')
169
+ elif src=='gdrive':
170
+ modelname=get_name(MODEL_LINK, True)
171
+ model=f'/workspace/sd/stable-diffusion-webui/models/Stable-diffusion/{modelname}'
172
+ if not os.path.exists(model):
173
+ gdown.download(url=MODEL_LINK, output=model, quiet=False, fuzzy=True)
174
+ clear_output()
175
+ else:
176
+ print('Model already exists')
177
+ else:
178
+ modelname=os.path.basename(MODEL_LINK)
179
+ model=f'/workspace/sd/stable-diffusion-webui/models/Stable-diffusion/{modelname}'
180
+ if not os.path.exists(model):
181
+ gdown.download(url=MODEL_LINK, output=model, quiet=False, fuzzy=True)
182
+ clear_output()
183
+ else:
184
+ print('Model already exists')
185
+
186
+ if os.path.exists(model) and os.path.getsize(model) > 1810671599:
187
+ print('Model downloaded, using the custom model.')
188
+ else:
189
+ call('rm '+model, shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
190
+ print('Wrong link, check that the link is valid')
191
+
192
+ else:
193
+ if Original_Model_Version == "v1.5":
194
+ model="/workspace/sd/stable-diffusion-webui/models/Stable-diffusion/SDv1-5.ckpt"
195
+ print('Using the original V1.5 model')
196
+ elif Original_Model_Version == "v2-512":
197
+ model='/workspace/sd/stable-diffusion-webui/models/Stable-diffusion/SDv2-512.ckpt'
198
+ if not os.path.exists('/workspace/sd/stable-diffusion-webui/models/Stable-diffusion/SDv2-512.ckpt'):
199
+ print('Downloading the V2-512 model...')
200
+ call('gdown -O '+model+' https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-nonema-pruned.ckpt', shell=True)
201
+ clear_output()
202
+ print('Using the original V2-512 model')
203
+ elif Original_Model_Version == "v2-768":
204
+ model="/workspace/sd/stable-diffusion-webui/models/Stable-diffusion/SDv2-768.ckpt"
205
+ print('Using the original V2-768 model')
206
+ elif Original_Model_Version == "SDXL":
207
+ model="/workspace/sd/stable-diffusion-webui/models/Stable-diffusion/sd_xl_base_1.0.safetensors"
208
+ print('Using the original SDXL model')
209
+
210
+ else:
211
+ model="/workspace/sd/stable-diffusion-webui/models/Stable-diffusion"
212
+ print('Wrong model version, try again')
213
+ try:
214
+ model
215
+ except:
216
+ model="/workspace/sd/stable-diffusion-webui/models/Stable-diffusion"
217
+
218
+ return model
219
+
220
+
221
+
222
+ def loradwn(LoRA_LINK):
223
+
224
+ import gdown  # missing in the original: gdown.download is used below but was only imported inside mdl()
+ os.makedirs('/workspace/sd/stable-diffusion-webui/models/Lora', exist_ok=True)
225
+
226
+ src=getsrc(LoRA_LINK)
227
+
228
+ if src=='civitai':
229
+ modelname=get_name(LoRA_LINK, False)
230
+ loramodel=f'/workspace/sd/stable-diffusion-webui/models/Lora/{modelname}'
231
+ if not os.path.exists(loramodel):
232
+ dwn(LoRA_LINK, loramodel, 'Downloading the LoRA model')
233
+ clear_output()
234
+ else:
235
+ print('Model already exists')
236
+ elif src=='gdrive':
237
+ modelname=get_name(LoRA_LINK, True)
238
+ loramodel=f'/workspace/sd/stable-diffusion-webui/models/Lora/{modelname}'
239
+ if not os.path.exists(loramodel):
240
+ gdown.download(url=LoRA_LINK, output=loramodel, quiet=False, fuzzy=True)
241
+ clear_output()
242
+ else:
243
+ print('Model already exists')
244
+ else:
245
+ modelname=os.path.basename(LoRA_LINK)
246
+ loramodel=f'/workspace/sd/stable-diffusion-webui/models/Lora/{modelname}'
247
+ if not os.path.exists(loramodel):
248
+ gdown.download(url=LoRA_LINK, output=loramodel, quiet=False, fuzzy=True)
249
+ clear_output()
250
+ else:
251
+ print('Model already exists')
252
+
253
+ if os.path.exists(loramodel) :
254
+ print('LoRA downloaded')
255
+ else:
256
+ print('Wrong link, check that the link is valid')
257
+
258
+
259
+
260
+ def CNet(ControlNet_Model, ControlNet_v2_Model):
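+ # Install/update the sd-webui-controlnet extension and download the requested v1/v2 ControlNet models from the CN_models.txt / CN_models_v2.txt lists.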
261
+
262
+ def download(url, model_dir):
263
+
264
+ filename = os.path.basename(urlparse(url).path)
265
+ pth = os.path.abspath(os.path.join(model_dir, filename))
266
+ if not os.path.exists(pth):
267
+ print('Downloading: '+os.path.basename(url))
268
+ download_url_to_file(url, pth, hash_prefix=None, progress=True)
269
+ else:
270
+ print(f"The model {filename} already exists")
271
+
272
+ wrngv1=False
273
+ os.chdir('/workspace/sd/stable-diffusion-webui/extensions')
274
+ if not os.path.exists("sd-webui-controlnet"):
275
+ call('git clone https://github.com/Mikubill/sd-webui-controlnet.git', shell=True)
276
+ os.chdir('/workspace')
277
+ else:
278
+ os.chdir('sd-webui-controlnet')
279
+ call('git reset --hard', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
280
+ call('git pull', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
281
+ os.chdir('/workspace')
282
+
283
+ mdldir="/workspace/sd/stable-diffusion-webui/extensions/sd-webui-controlnet/models"
284
+ for filename in os.listdir(mdldir):
285
+ if "_sd14v1" in filename:
286
+ renamed = re.sub("_sd14v1", "-fp16", filename)
287
+ os.rename(os.path.join(mdldir, filename), os.path.join(mdldir, renamed))
288
+
289
+ call('wget -q -O CN_models.txt https://github.com/TheLastBen/fast-stable-diffusion/raw/main/AUTOMATIC1111_files/CN_models.txt', shell=True)
290
+ call('wget -q -O CN_models_v2.txt https://github.com/TheLastBen/fast-stable-diffusion/raw/main/AUTOMATIC1111_files/CN_models_v2.txt', shell=True)
291
+
292
+ with open("CN_models.txt", 'r') as f:
293
+ mdllnk = f.read().splitlines()
294
+ with open("CN_models_v2.txt", 'r') as d:
295
+ mdllnk_v2 = d.read().splitlines()
296
+ call('rm CN_models.txt CN_models_v2.txt', shell=True)
297
+
298
+ cfgnames=[os.path.basename(url).split('.')[0]+'.yaml' for url in mdllnk_v2]
299
+ os.chdir('/workspace/sd/stable-diffusion-webui/extensions/sd-webui-controlnet/models')
300
+ for name in cfgnames:
301
+ run(['cp', 'cldm_v21.yaml', name])
302
+ os.chdir('/workspace')
303
+
304
+ if ControlNet_Model == "All" or ControlNet_Model == "all" :
305
+ for lnk in mdllnk:
306
+ download(lnk, mdldir)
307
+ clear_output()
308
+
309
+
310
+ elif ControlNet_Model == "15":
311
+ mdllnk=list(filter(lambda x: 't2i' in x, mdllnk))
312
+ for lnk in mdllnk:
313
+ download(lnk, mdldir)
314
+ clear_output()
315
+
316
+
317
+ elif ControlNet_Model.isdigit() and int(ControlNet_Model)-1<14 and int(ControlNet_Model)>0:
318
+ download(mdllnk[int(ControlNet_Model)-1], mdldir)
319
+ clear_output()
320
+
321
+ elif ControlNet_Model == "none":
322
+ pass
323
+ clear_output()
324
+
325
+ else:
326
+ print('Wrong ControlNet V1 choice, try again')
327
+ wrngv1=True
328
+
329
+
330
+ if ControlNet_v2_Model == "All" or ControlNet_v2_Model == "all" :
331
+ for lnk_v2 in mdllnk_v2:
332
+ download(lnk_v2, mdldir)
333
+ if not wrngv1:
334
+ clear_output()
335
+ done()
336
+
337
+ elif ControlNet_v2_Model.isdigit() and int(ControlNet_v2_Model)-1<5:
338
+ download(mdllnk_v2[int(ControlNet_v2_Model)-1], mdldir)
339
+ if not wrngv1:
340
+ clear_output()
341
+ done()
342
+
343
+ elif ControlNet_v2_Model == "none":
344
+ pass
345
+ if not wrngv1:
346
+ clear_output()
347
+ done()
348
+
349
+ else:
350
+ print('Wrong ControlNet V2 choice, try again')
351
+
352
+
353
+
354
+ def sd(User, Password, model):
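+ # Patch gradio and the webui modules for the RunPod proxy and return the command-line flags used to launch AUTOMATIC1111 on port 3000.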
355
+
356
+ import gradio
357
+
358
+ gradio.close_all()
359
+
360
+ auth=f"--gradio-auth {User}:{Password}"
361
+ if User =="" or Password=="":
362
+ auth=""
363
+
364
+ call('wget -q -O /usr/local/lib/python3.10/dist-packages/gradio/blocks.py https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/main/AUTOMATIC1111_files/blocks.py', shell=True)
365
+
366
+ os.chdir('/workspace/sd/stable-diffusion-webui/modules')
367
+
368
+ call("sed -i 's@possible_sd_paths =.*@possible_sd_paths = [\"/workspace/sd/stablediffusion\"]@' /workspace/sd/stable-diffusion-webui/modules/paths.py", shell=True)
369
+ call("sed -i 's@\.\.\/@src/@g' /workspace/sd/stable-diffusion-webui/modules/paths.py", shell=True)
370
+ call("sed -i 's@src\/generative-models@generative-models@g' /workspace/sd/stable-diffusion-webui/modules/paths.py", shell=True)
371
+
372
+ call("sed -i 's@\[\"sd_model_checkpoint\"\]@\[\"sd_model_checkpoint\", \"sd_vae\", \"CLIP_stop_at_last_layers\", \"inpainting_mask_weight\", \"initial_noise_multiplier\"\]@g' /workspace/sd/stable-diffusion-webui/modules/shared.py", shell=True)
373
+
374
+ call("sed -i 's@print(\"No module.*@@' /workspace/sd/stablediffusion/ldm/modules/diffusionmodules/model.py", shell=True)
375
+ os.chdir('/workspace/sd/stable-diffusion-webui')
376
+ clear_output()
377
+
378
+ podid=os.environ.get('RUNPOD_POD_ID')
379
+ localurl=f"{podid}-3001.proxy.runpod.net"
380
+
381
+ for line in fileinput.input('/usr/local/lib/python3.10/dist-packages/gradio/blocks.py', inplace=True):
382
+ if line.strip().startswith('self.server_name ='):
383
+ line = f' self.server_name = "{localurl}"\n'
384
+ if line.strip().startswith('self.protocol = "https"'):
385
+ line = ' self.protocol = "https"\n'
386
+ if line.strip().startswith('if self.local_url.startswith("https") or self.is_colab'):
387
+ line = ''
388
+ if line.strip().startswith('else "http"'):
389
+ line = ''
390
+ sys.stdout.write(line)
391
+
392
+ if model=="":
393
+ mdlpth=""
394
+ else:
395
+ if os.path.isfile(model):
396
+ mdlpth="--ckpt "+model
397
+ else:
398
+ mdlpth="--ckpt-dir "+model
399
+
400
+ configf="--disable-console-progressbars --no-half-vae --disable-safe-unpickle --api --no-download-sd-model --opt-sdp-attention --enable-insecure-extension-access --skip-version-check --listen --port 3000 "+auth+" "+mdlpth
401
+
402
+ return configf
403
+
404
+
405
+
406
+ def save(Huggingface_Write_token):
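+ # Archive /workspace/sd (model weights excluded) with zstd and push it to a private fast-stable-diffusion dataset repo on the user's HuggingFace account.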
407
+
408
+ from slugify import slugify
409
+ from huggingface_hub import HfApi, CommitOperationAdd, create_repo
410
+
411
+ if Huggingface_Write_token=="":
412
+ print('A huggingface write token is required')
413
+
414
+ else:
415
+ os.chdir('/workspace')
416
+
417
+ if os.path.exists('sd'):
418
+
419
+ call('tar --exclude="stable-diffusion-webui/models/*/*" --exclude="sd-webui-controlnet/models/*" --zstd -cf sd_backup_rnpd.tar.zst sd', shell=True)
420
+ api = HfApi()
421
+ username = api.whoami(token=Huggingface_Write_token)["name"]
422
+
423
+ repo_id = f"{username}/{slugify('fast-stable-diffusion')}"
424
+
425
+ print("Backing up...")
426
+
427
+ operations = [CommitOperationAdd(path_in_repo="sd_backup_rnpd.tar.zst", path_or_fileobj="/workspace/sd_backup_rnpd.tar.zst")]
428
+
429
+ create_repo(repo_id,private=True, token=Huggingface_Write_token, exist_ok=True, repo_type="dataset")
430
+
431
+ api.create_commit(
432
+ repo_id=repo_id,
433
+ repo_type="dataset",
434
+ operations=operations,
435
+ commit_message="SD folder Backup",
436
+ token=Huggingface_Write_token
437
+ )
438
+
439
+ call('rm sd_backup_rnpd.tar.zst', shell=True)
440
+ clear_output()
441
+
442
+ done()
443
+
444
+ else:
445
+ print('Nothing to backup')
446
+
447
+
448
+
449
+
450
+ def getsrc(url):
451
+
452
+ parsed_url = urlparse(url)
453
+
454
+ if parsed_url.netloc == 'civitai.com':
455
+ src='civitai'
456
+ elif parsed_url.netloc == 'drive.google.com':
457
+ src='gdrive'
458
+ elif parsed_url.netloc == 'huggingface.co':
459
+ src='huggingface'
460
+ else:
461
+ src='others'
462
+ return src
463
+
464
+
465
+
466
+ def get_name(url, gdrive):
467
+
468
+ from gdown.download import get_url_from_gdrive_confirmation
469
+
470
+ if not gdrive:
471
+ response = requests.get(url, allow_redirects=False)
472
+ if "Location" in response.headers:
473
+ redirected_url = response.headers["Location"]
474
+ quer = parse_qs(urlparse(redirected_url).query)
475
+ if "response-content-disposition" in quer:
476
+ disp_val = quer["response-content-disposition"][0].split(";")
477
+ for vals in disp_val:
478
+ if vals.strip().startswith("filename="):
479
+ filenm=unquote(vals.split("=", 1)[1].strip())
480
+ return filenm.replace("\"","")
481
+ else:
482
+ headers = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36"}
483
+ lnk="https://drive.google.com/uc?id={id}&export=download".format(id=url[url.find("/d/")+3:url.find("/view")])
484
+ res = requests.session().get(lnk, headers=headers, stream=True, verify=True)
485
+ res = requests.session().get(get_url_from_gdrive_confirmation(res.text), headers=headers, stream=True, verify=True)
486
+ content_disposition = six.moves.urllib_parse.unquote(res.headers["Content-Disposition"])
487
+ filenm = re.search(r"filename\*=UTF-8''(.*)", content_disposition).groups()[0].replace(os.path.sep, "_")
488
+ return filenm
489
+
490
+
491
+
492
+
493
+ def done():
494
+ done = widgets.Button(
495
+ description='Done!',
496
+ disabled=True,
497
+ button_style='success',
498
+ tooltip='',
499
+ icon='check'
500
+ )
501
+ display(done)
sdxllorarunpod.py ADDED
@@ -0,0 +1,1131 @@
1
+ from IPython.display import clear_output
2
+ from subprocess import call, getoutput, Popen
3
+ from IPython.display import display
4
+ import ipywidgets as widgets
5
+ import io
6
+ from PIL import Image, ImageDraw, ImageOps
7
+ import fileinput
8
+ import time
9
+ import os
10
+ from os import listdir
11
+ from os.path import isfile
12
+ import random
13
+ import sys
14
+ from io import BytesIO
15
+ import requests
16
+ from collections import defaultdict
17
+ from math import log, sqrt
18
+ import numpy as np
19
+ import sys
20
+ import fileinput
21
+ from subprocess import check_output
22
+ import six
23
+ import base64
24
+ import re
25
+
26
+ from urllib.parse import urlparse, parse_qs, unquote
27
+ import urllib.request
28
+ from urllib.request import urlopen, Request
29
+
30
+ import tempfile
31
+ from tqdm import tqdm
32
+
33
+
34
+
35
+
36
+ def Deps(force_reinstall):
37
+
38
+ if not force_reinstall and os.path.exists('/usr/local/lib/python3.10/dist-packages/safetensors'):
39
+ ntbks()
40
+ call('pip install --root-user-action=ignore --disable-pip-version-check -qq diffusers==0.18.1', shell=True, stdout=open('/dev/null', 'w'))
41
+ print('Modules and notebooks updated, dependencies already installed')
42
+ os.environ['TORCH_HOME'] = '/workspace/cache/torch'
43
+ os.environ['PYTHONWARNINGS'] = 'ignore'
44
+ else:
45
+ call('pip install --root-user-action=ignore --disable-pip-version-check --no-deps -qq gdown PyWavelets numpy==1.23.5 accelerate==0.12.0 --force-reinstall', shell=True, stdout=open('/dev/null', 'w'))
46
+ ntbks()
47
+ if os.path.exists('deps'):
48
+ call("rm -r deps", shell=True)
49
+ if os.path.exists('diffusers'):
50
+ call("rm -r diffusers", shell=True)
51
+ call('mkdir deps', shell=True)
52
+ if not os.path.exists('cache'):
53
+ call('mkdir cache', shell=True)
54
+ os.chdir('deps')
55
+ dwn("https://huggingface.co/TheLastBen/dependencies/resolve/main/rnpddeps-t2.tar.zst", "/workspace/deps/rnpddeps-t2.tar.zst", "Installing dependencies")
56
+ call('tar -C / --zstd -xf rnpddeps-t2.tar.zst', shell=True, stdout=open('/dev/null', 'w'))
57
+ call("sed -i 's@~/.cache@/workspace/cache@' /usr/local/lib/python3.10/dist-packages/transformers/utils/hub.py", shell=True)
58
+ os.chdir('/workspace')
59
+ call('pip install --root-user-action=ignore --disable-pip-version-check -qq diffusers==0.18.1', shell=True, stdout=open('/dev/null', 'w'))
60
+ call("git clone --depth 1 -q --branch main https://github.com/TheLastBen/diffusers", shell=True, stdout=open('/dev/null', 'w'))
61
+ #call('pip install --root-user-action=ignore --disable-pip-version-check -qq gradio==3.41.0', shell=True, stdout=open('/dev/null', 'w'))
62
+ call("rm -r deps", shell=True)
63
+ os.chdir('/workspace')
64
+ os.environ['TORCH_HOME'] = '/workspace/cache/torch'
65
+ os.environ['PYTHONWARNINGS'] = 'ignore'
66
+ call("sed -i 's@text = _formatwarnmsg(msg)@text =\"\"@g' /usr/lib/python3.10/warnings.py", shell=True)
67
+ clear_output()
68
+
69
+ done()
70
+
71
+
72
+ def dwn(url, dst, msg):
73
+ file_size = None
74
+ req = Request(url, headers={"User-Agent": "torch.hub"})
75
+ u = urlopen(req)
76
+ meta = u.info()
77
+ if hasattr(meta, 'getheaders'):
78
+ content_length = meta.getheaders("Content-Length")
79
+ else:
80
+ content_length = meta.get_all("Content-Length")
81
+ if content_length is not None and len(content_length) > 0:
82
+ file_size = int(content_length[0])
83
+
84
+ with tqdm(total=file_size, disable=False, mininterval=0.5,
85
+ bar_format=msg+' |{bar:20}| {percentage:3.0f}%') as pbar:
86
+ with open(dst, "wb") as f:
87
+ while True:
88
+ buffer = u.read(8192)
89
+ if len(buffer) == 0:
90
+ break
91
+ f.write(buffer)
92
+ pbar.update(len(buffer))
93
+ f.close()
94
+
95
+
96
+ def ntbks():
97
+
98
+ os.chdir('/workspace')
99
+ if not os.path.exists('Latest_Notebooks'):
100
+ call('mkdir Latest_Notebooks', shell=True)
101
+ else:
102
+ call('rm -r Latest_Notebooks', shell=True)
103
+ call('mkdir Latest_Notebooks', shell=True)
104
+ os.chdir('/workspace/Latest_Notebooks')
105
+ call('wget -q -i https://huggingface.co/datasets/TheLastBen/RNPD/raw/main/Notebooks.txt', shell=True)
106
+ call('rm Notebooks.txt', shell=True)
107
+ os.chdir('/workspace')
108
+
109
+ def done():
110
+ done = widgets.Button(
111
+ description='Done!',
112
+ disabled=True,
113
+ button_style='success',
114
+ tooltip='',
115
+ icon='check'
116
+ )
117
+ display(done)
118
+
119
+
120
+
121
+ def mdlvxl():
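+ # Sparse-clone the stabilityai/stable-diffusion-xl-base-1.0 repo structure, then download the four safetensors weight files with progress bars.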
122
+
123
+ os.chdir('/workspace')
124
+
125
+ if os.path.exists('stable-diffusion-XL') and not os.path.exists('/workspace/stable-diffusion-XL/unet/diffusion_pytorch_model.safetensors'):
126
+ call('rm -r stable-diffusion-XL', shell=True)
127
+ if not os.path.exists('stable-diffusion-XL'):
128
+ print('Downloading SDXL model...')
129
+ call('mkdir stable-diffusion-XL', shell=True)
130
+ os.chdir('stable-diffusion-XL')
131
+ call('git init', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
132
+ call('git lfs install --system --skip-repo', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
133
+ call('git remote add -f origin https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
134
+ call('git config core.sparsecheckout true', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
135
+ call('echo -e "\nscheduler\ntext_encoder\ntext_encoder_2\ntokenizer\ntokenizer_2\nunet\nvae\nfeature_extractor\nmodel_index.json\n!*.safetensors\n!*.bin\n!*.onnx*\n!*.xml" > .git/info/sparse-checkout', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
136
+ call('git pull origin main', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
137
+ dwn('https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/text_encoder/model.safetensors', 'text_encoder/model.safetensors', '1/4')
138
+ dwn('https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/text_encoder_2/model.safetensors', 'text_encoder_2/model.safetensors', '2/4')
139
+ dwn('https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/vae/diffusion_pytorch_model.safetensors', 'vae/diffusion_pytorch_model.safetensors', '3/4')
140
+ dwn('https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/unet/diffusion_pytorch_model.safetensors', 'unet/diffusion_pytorch_model.safetensors', '4/4')
141
+ call('rm -r .git', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
142
+ os.chdir('/workspace')
143
+ clear_output()
144
+ while not os.path.exists('/workspace/stable-diffusion-XL/unet/diffusion_pytorch_model.safetensors'):
145
+ print('Invalid HF token, make sure you have access to the model')
146
+ time.sleep(8)
147
+ print('Using SDXL model')  # the while loop above only exits once the UNet weights exist, so no separate check is needed
151
+
152
+ call("sed -i 's@\"force_upcast.*@@' /workspace/stable-diffusion-XL/vae/config.json", shell=True)
153
+
154
+
155
+
156
+ def downloadmodel_hfxl(Path_to_HuggingFace):
157
+
158
+ os.chdir('/workspace')
159
+ if os.path.exists('stable-diffusion-custom'):
160
+ call("rm -r stable-diffusion-custom", shell=True)
161
+ clear_output()
162
+
163
+ if os.path.exists('Fast-Dreambooth/token.txt'):
164
+ with open("Fast-Dreambooth/token.txt") as f:
165
+ token = f.read()
166
+ authe=f'https://USER:{token}@'
167
+ else:
168
+ authe="https://"
169
+
170
+ clear_output()
171
+ call("mkdir stable-diffusion-custom", shell=True)
172
+ os.chdir("stable-diffusion-custom")
173
+ call("git init", shell=True)
174
+ call("git lfs install --system --skip-repo", shell=True)
175
+ call('git remote add -f origin '+authe+'huggingface.co/'+Path_to_HuggingFace, shell=True)
176
+ call("git config core.sparsecheckout true", shell=True)
177
+ call('echo -e "\nscheduler\ntext_encoder\ntokenizer\nunet\nvae\nfeature_extractor\nmodel_index.json\n!*.fp16.safetensors" > .git/info/sparse-checkout', shell=True)
178
+ call("git pull origin main", shell=True)
179
+ if os.path.exists('unet/diffusion_pytorch_model.safetensors'):
180
+ call("rm -r .git", shell=True)
181
+ os.chdir('/workspace')
182
+ clear_output()
183
+ done()
184
+ while not os.path.exists('/workspace/stable-diffusion-custom/unet/diffusion_pytorch_model.safetensors'):
185
+ print('Check the link you provided')
186
+ os.chdir('/workspace')
187
+ time.sleep(5)
188
+
189
+
190
+
191
+ def downloadmodel_link_xl(MODEL_LINK):
192
+
193
+ import wget
194
+ import gdown
195
+ from gdown.download import get_url_from_gdrive_confirmation
196
+
197
+ def getsrc(url):
198
+ parsed_url = urlparse(url)
199
+ if parsed_url.netloc == 'civitai.com':
200
+ src='civitai'
201
+ elif parsed_url.netloc == 'drive.google.com':
202
+ src='gdrive'
203
+ elif parsed_url.netloc == 'huggingface.co':
204
+ src='huggingface'
205
+ else:
206
+ src='others'
207
+ return src
208
+
209
+ src=getsrc(MODEL_LINK)
210
+
211
+ def get_name(url, gdrive):
212
+ if not gdrive:
213
+ response = requests.get(url, allow_redirects=False)
214
+ if "Location" in response.headers:
215
+ redirected_url = response.headers["Location"]
216
+ quer = parse_qs(urlparse(redirected_url).query)
217
+ if "response-content-disposition" in quer:
218
+ disp_val = quer["response-content-disposition"][0].split(";")
219
+ for vals in disp_val:
220
+ if vals.strip().startswith("filename="):
221
+ filenm=unquote(vals.split("=", 1)[1].strip())
222
+ return filenm.replace("\"","")
223
+ else:
224
+ headers = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36"}
225
+ lnk="https://drive.google.com/uc?id={id}&export=download".format(id=url[url.find("/d/")+3:url.find("/view")])
226
+ res = requests.session().get(lnk, headers=headers, stream=True, verify=True)
227
+ res = requests.session().get(get_url_from_gdrive_confirmation(res.text), headers=headers, stream=True, verify=True)
228
+ content_disposition = six.moves.urllib_parse.unquote(res.headers["Content-Disposition"])
229
+ filenm = re.search(r"filename\*=UTF-8''(.*)", content_disposition).groups()[0].replace(os.path.sep, "_")
230
+ return filenm
231
+
232
+ if src=='civitai':
233
+ modelname=get_name(MODEL_LINK, False)
234
+ elif src=='gdrive':
235
+ modelname=get_name(MODEL_LINK, True)
236
+ else:
237
+ modelname=os.path.basename(MODEL_LINK)
238
+
239
+
240
+ os.chdir('/workspace')
241
+ if src=='huggingface':
242
+ dwn(MODEL_LINK, modelname,'Downloading the Model')
243
+ else:
244
+ call("gdown --fuzzy " +MODEL_LINK+ " -O "+modelname, shell=True)
245
+
246
+ if os.path.exists(modelname):
247
+ if os.path.getsize(modelname) > 1810671599:
248
+
249
+ print('Converting to diffusers...')
250
+ call('python /workspace/diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path '+modelname+' --dump_path stable-diffusion-custom --from_safetensors', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
251
+
252
+ if os.path.exists('stable-diffusion-custom/unet/diffusion_pytorch_model.bin'):
253
+ os.chdir('/workspace')
254
+ clear_output()
255
+ done()
256
+ else:
257
+ while not os.path.exists('stable-diffusion-custom/unet/diffusion_pytorch_model.bin'):
258
+ print('Conversion error')
259
+ os.chdir('/workspace')
260
+ time.sleep(5)
261
+ else:
262
+ while os.path.getsize(modelname) < 1810671599:
263
+ print('Wrong link, check that the link is valid')
264
+ os.chdir('/workspace')
265
+ time.sleep(5)
266
+
267
+
268
+
269
+ def downloadmodel_path_xl(MODEL_PATH):
270
+
271
+ import wget
272
+ os.chdir('/workspace')
273
+ clear_output()
274
+ if os.path.exists(str(MODEL_PATH)):
275
+
276
+ print('Converting to diffusers...')
277
+ call('python /workspace/diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path '+MODEL_PATH+' --dump_path stable-diffusion-custom --from_safetensors', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
278
+
279
+ if os.path.exists('stable-diffusion-custom/unet/diffusion_pytorch_model.bin'):
280
+ clear_output()
281
+ done()
282
+ while not os.path.exists('stable-diffusion-custom/unet/diffusion_pytorch_model.bin'):
283
+ print('Conversion error')
284
+ os.chdir('/workspace')
285
+ time.sleep(5)
286
+ else:
287
+ while not os.path.exists(str(MODEL_PATH)):
288
+ print('Wrong path, use the file explorer to copy the path')
289
+ os.chdir('/workspace')
290
+ time.sleep(5)
291
+
292
+
293
+
294
+
295
+ def dls_xlf(Path_to_HuggingFace, MODEL_PATH, MODEL_LINK):
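+ # Dispatch to the right downloader (HF repo, local checkpoint, direct link, or the default SDXL base) and return the resulting diffusers model path.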
296
+
297
+ os.chdir('/workspace')
298
+
299
+ if Path_to_HuggingFace != "":
300
+ downloadmodel_hfxl(Path_to_HuggingFace)
301
+ MODEL_NAMExl="/workspace/stable-diffusion-custom"
302
+
303
+ elif MODEL_PATH !="":
304
+
305
+ downloadmodel_path_xl(MODEL_PATH)
306
+ MODEL_NAMExl="/workspace/stable-diffusion-custom"
307
+
308
+ elif MODEL_LINK !="":
309
+
310
+ downloadmodel_link_xl(MODEL_LINK)
311
+ MODEL_NAMExl="/workspace/stable-diffusion-custom"
312
+
313
+ else:
314
+ mdlvxl()
315
+ MODEL_NAMExl="/workspace/stable-diffusion-XL"
316
+
317
+ return MODEL_NAMExl
318
+
319
+
320
+
321
+ def sess_xl(Session_Name, MODEL_NAMExl):
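+ # Create or load a training session under /workspace/Fast-Dreambooth/Sessions/<name> and return the paths used by the other cells.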
322
+ import gdown
323
+ import wget
324
+ os.chdir('/workspace')
325
+ PT=""
326
+
327
+ while Session_Name=="":
328
+ print('Input the Session Name:')
329
+ Session_Name=input("")
330
+ Session_Name=Session_Name.replace(" ","_")
331
+
332
+ WORKSPACE='/workspace/Fast-Dreambooth'
333
+
334
+ INSTANCE_NAME=Session_Name
335
+ OUTPUT_DIR="/workspace/models/"+Session_Name
336
+ SESSION_DIR=WORKSPACE+"/Sessions/"+Session_Name
337
+ INSTANCE_DIR=SESSION_DIR+"/instance_images"
338
+ CAPTIONS_DIR=SESSION_DIR+'/captions'
339
+ MDLPTH=str(SESSION_DIR+"/"+Session_Name+'.safetensors')
340
+
341
+
342
+ if os.path.exists(str(SESSION_DIR)) and not os.path.exists(MDLPTH):
343
+ print('Loading session with no previous LoRA model')
344
+ if MODEL_NAMExl=="":
345
+ print('No model found, use the "Model Download" cell to download a model.')
346
+ else:
347
+ print('Session Loaded, proceed')
348
+
349
+ elif not os.path.exists(str(SESSION_DIR)):
350
+ call('mkdir -p '+INSTANCE_DIR, shell=True)
351
+ print('Creating session...')
352
+ if MODEL_NAMExl=="":
353
+ print('No model found, use the "Model Download" cell to download a model.')
354
+ else:
355
+ print('Session created, proceed to uploading instance images')
356
+ if MODEL_NAMExl=="":
357
+ print('No model found, use the "Model Download" cell to download a model.')
358
+
359
+ else:
360
+ print('Session Loaded, proceed')
361
+
362
+
363
+ return WORKSPACE, Session_Name, INSTANCE_NAME, OUTPUT_DIR, SESSION_DIR, INSTANCE_DIR, CAPTIONS_DIR, MDLPTH, MODEL_NAMExl
364
+
365
+
366
+
367
+ def uplder(Remove_existing_instance_images, Crop_images, Crop_size, IMAGES_FOLDER_OPTIONAL, INSTANCE_DIR, CAPTIONS_DIR):
368
+
369
+ if os.path.exists(INSTANCE_DIR+"/.ipynb_checkpoints"):
370
+ call('rm -r '+INSTANCE_DIR+'/.ipynb_checkpoints', shell=True)
371
+
372
+ uploader = widgets.FileUpload(description="Choose images",accept='image/*, .txt', multiple=True)
373
+ Upload = widgets.Button(
374
+ description='Upload',
375
+ disabled=False,
376
+ button_style='info',
377
+ tooltip='Click to upload the chosen instance images',
378
+ icon=''
379
+ )
380
+
381
+
382
+ def up(Upload):
383
+ with out:
384
+ uploader.close()
385
+ Upload.close()
386
+ upld(Remove_existing_instance_images, Crop_images, Crop_size, IMAGES_FOLDER_OPTIONAL, INSTANCE_DIR, CAPTIONS_DIR, uploader)
387
+ done()
388
+ out=widgets.Output()
389
+
390
+ if IMAGES_FOLDER_OPTIONAL=="":
391
+ Upload.on_click(up)
392
+ display(uploader, Upload, out)
393
+ else:
394
+ upld(Remove_existing_instance_images, Crop_images, Crop_size, IMAGES_FOLDER_OPTIONAL, INSTANCE_DIR, CAPTIONS_DIR, uploader)
395
+ done()
396
+
397
+
398
+
399
+ def upld(Remove_existing_instance_images, Crop_images, Crop_size, IMAGES_FOLDER_OPTIONAL, INSTANCE_DIR, CAPTIONS_DIR, uploader):
400
+
401
+ from tqdm import tqdm
402
+ if Remove_existing_instance_images:
403
+ if os.path.exists(str(INSTANCE_DIR)):
404
+ call("rm -r " +INSTANCE_DIR, shell=True)
405
+ if os.path.exists(str(CAPTIONS_DIR)):
406
+ call("rm -r " +CAPTIONS_DIR, shell=True)
407
+
408
+
409
+ if not os.path.exists(str(INSTANCE_DIR)):
410
+ call("mkdir -p " +INSTANCE_DIR, shell=True)
411
+ if not os.path.exists(str(CAPTIONS_DIR)):
412
+ call("mkdir -p " +CAPTIONS_DIR, shell=True)
413
+
414
+
415
+ if IMAGES_FOLDER_OPTIONAL !="":
416
+ if os.path.exists(IMAGES_FOLDER_OPTIONAL+"/.ipynb_checkpoints"):
417
+ call('rm -r '+IMAGES_FOLDER_OPTIONAL+'/.ipynb_checkpoints', shell=True)
418
+
419
+ if any(file.endswith('.{}'.format('txt')) for file in os.listdir(IMAGES_FOLDER_OPTIONAL)):
420
+ call('mv '+IMAGES_FOLDER_OPTIONAL+'/*.txt '+CAPTIONS_DIR, shell=True)
421
+ if Crop_images:
422
+ os.chdir(str(IMAGES_FOLDER_OPTIONAL))
423
+ call('find . -name "* *" -type f | rename ' "'s/ /-/g'", shell=True)
424
+ os.chdir('/workspace')
425
+ for filename in tqdm(os.listdir(IMAGES_FOLDER_OPTIONAL), bar_format=' |{bar:15}| {n_fmt}/{total_fmt} Uploaded'):
426
+ extension = filename.split(".")[-1]
427
+ identifier=filename.split(".")[0]
428
+ new_path_with_file = os.path.join(INSTANCE_DIR, filename)
429
+ file = Image.open(IMAGES_FOLDER_OPTIONAL+"/"+filename)
430
+ file=file.convert("RGB")
431
+ file=ImageOps.exif_transpose(file)
432
+ width, height = file.size
433
+ if file.size !=(Crop_size, Crop_size):
434
+ image=crop_image(file, Crop_size)
435
+ if extension.upper()=="JPG" or extension.upper()=="jpg":
436
+ image[0].save(new_path_with_file, format="JPEG", quality = 100)
437
+ else:
438
+ image[0].save(new_path_with_file, format=extension.upper())
439
+
440
+ else:
441
+ call("cp \'"+IMAGES_FOLDER_OPTIONAL+"/"+filename+"\' "+INSTANCE_DIR, shell=True)
442
+
443
+ else:
444
+ for filename in tqdm(os.listdir(IMAGES_FOLDER_OPTIONAL), bar_format=' |{bar:15}| {n_fmt}/{total_fmt} Uploaded'):
445
+ call("cp -r " +IMAGES_FOLDER_OPTIONAL+"/. " +INSTANCE_DIR, shell=True)
446
+
447
+ elif IMAGES_FOLDER_OPTIONAL =="":
448
+ up=""
449
+ for file in uploader.value:
450
+ filename = file['name']
451
+ if filename.split(".")[-1]=="txt":
452
+ with open(CAPTIONS_DIR+'/'+filename, 'w') as f:
453
+ f.write(bytes(file['content']).decode())
454
+ up=[file for file in uploader.value if not file['name'].endswith('.txt')]
455
+ if Crop_images:
456
+ for file in tqdm(up, bar_format=' |{bar:15}| {n_fmt}/{total_fmt} Uploaded'):
457
+ filename = file['name']
458
+ img = Image.open(io.BytesIO(file['content']))
459
+ img=img.convert("RGB")
460
+ img=ImageOps.exif_transpose(img)
461
+ extension = filename.split(".")[-1]
462
+ identifier=filename.split(".")[0]
463
+
464
+ if extension.upper()=="JPG" or extension.upper()=="jpg":
465
+ img.save(INSTANCE_DIR+"/"+filename, format="JPEG", quality = 100)
466
+ else:
467
+ img.save(INSTANCE_DIR+"/"+filename, format=extension.upper())
468
+
469
+ new_path_with_file = os.path.join(INSTANCE_DIR, filename)
470
+ file = Image.open(new_path_with_file)
471
+ width, height = file.size
472
+ if file.size !=(Crop_size, Crop_size):
473
+ image=crop_image(file, Crop_size)
474
+ if extension.upper()=="JPG" or extension.upper()=="jpg":
475
+ image[0].save(new_path_with_file, format="JPEG", quality = 100)
476
+ else:
477
+ image[0].save(new_path_with_file, format=extension.upper())
478
+
479
+ else:
480
+ for file in tqdm(uploader.value, bar_format=' |{bar:15}| {n_fmt}/{total_fmt} Uploaded'):
481
+ filename = file['name']
482
+ img = Image.open(io.BytesIO(file['content']))
483
+ img=img.convert("RGB")
484
+ extension = filename.split(".")[-1]
485
+ identifier=filename.split(".")[0]
486
+
487
+ if extension.upper()=="JPG" or extension.upper()=="jpg":
488
+ img.save(INSTANCE_DIR+"/"+filename, format="JPEG", quality = 100)
489
+ else:
490
+ img.save(INSTANCE_DIR+"/"+filename, format=extension.upper())
491
+
492
+
493
+ os.chdir(INSTANCE_DIR)
494
+ call('find . -name "* *" -type f | rename ' "'s/ /-/g'", shell=True)
495
+ os.chdir(CAPTIONS_DIR)
496
+ call('find . -name "* *" -type f | rename ' "'s/ /-/g'", shell=True)
497
+ os.chdir('/workspace')
498
+
499
+
500
+
501
+
502
+ def caption(CAPTIONS_DIR, INSTANCE_DIR):
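+ # Minimal ipywidgets caption editor: select an instance image, preview it at 420px and edit/save its .txt caption.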
503
+
504
+ paths=""
505
+ out=""
506
+ widgets_l=""
507
+ clear_output()
508
+ def Caption(path):
509
+ if path!="Select an instance image to caption":
510
+
511
+ name = os.path.splitext(os.path.basename(path))[0]
512
+ ext=os.path.splitext(os.path.basename(path))[-1][1:]
513
+ if ext=="jpg" or "JPG":
514
+ ext="JPEG"
515
+
516
+ if os.path.exists(CAPTIONS_DIR+"/"+name + '.txt'):
517
+ with open(CAPTIONS_DIR+"/"+name + '.txt', 'r') as f:
518
+ text = f.read()
519
+ else:
520
+ with open(CAPTIONS_DIR+"/"+name + '.txt', 'w') as f:
521
+ f.write("")
522
+ with open(CAPTIONS_DIR+"/"+name + '.txt', 'r') as f:
523
+ text = f.read()
524
+
525
+ img=Image.open(os.path.join(INSTANCE_DIR,path))
526
+ img=img.convert("RGB")
527
+ img=img.resize((420, 420))
528
+ image_bytes = BytesIO()
529
+ img.save(image_bytes, format=ext, quality=10)
530
+ image_bytes.seek(0)
531
+ image_data = image_bytes.read()
532
+ img= image_data
533
+ image = widgets.Image(
534
+ value=img,
535
+ width=420,
536
+ height=420
537
+ )
538
+ text_area = widgets.Textarea(value=text, description='', disabled=False, layout={'width': '300px', 'height': '120px'})
539
+
540
+
541
+ def update_text(text):
542
+ with open(CAPTIONS_DIR+"/"+name + '.txt', 'w') as f:
543
+ f.write(text)
544
+
545
+ button = widgets.Button(description='Save', button_style='success')
546
+ button.on_click(lambda b: update_text(text_area.value))
547
+
548
+ return widgets.VBox([widgets.HBox([image, text_area, button])])
549
+
550
+
551
+ paths = os.listdir(INSTANCE_DIR)
552
+ widgets_l = widgets.Select(options=["Select an instance image to caption"]+paths, rows=25)
553
+
554
+
555
+ out = widgets.Output()
556
+
557
+ def click(change):
558
+ with out:
559
+ out.clear_output()
560
+ display(Caption(change.new))
561
+
562
+ widgets_l.observe(click, names='value')
563
+ display(widgets.HBox([widgets_l, out]))
564
+
565
+
566
+
567
+ def dbtrainxl(Unet_Training_Epochs, Text_Encoder_Training_Epochs, Unet_Learning_Rate, Text_Encoder_Learning_Rate, dim, Offset_Noise, Resolution, MODEL_NAME, SESSION_DIR, INSTANCE_DIR, CAPTIONS_DIR, External_Captions, INSTANCE_NAME, Session_Name, OUTPUT_DIR, ofstnselvl, Save_VRAM):
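+ # Run the text-encoder and/or UNet LoRA trainers via accelerate launch with the chosen epochs, learning rates, dim, resolution and offset-noise settings.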
568
+
569
+
570
+ if os.path.exists(INSTANCE_DIR+"/.ipynb_checkpoints"):
571
+ call('rm -r '+INSTANCE_DIR+'/.ipynb_checkpoints', shell=True)
572
+ if os.path.exists(CAPTIONS_DIR+"/.ipynb_checkpoints"):
573
+ call('rm -r '+CAPTIONS_DIR+'/.ipynb_checkpoints', shell=True)
574
+
575
+
576
+ Seed=random.randint(1, 999999)
577
+
578
+ ofstnse=""
579
+ if Offset_Noise:
580
+ ofstnse="--offset_noise"
581
+
582
+ GC=''
583
+ if Save_VRAM:
584
+ GC='--gradient_checkpointing'
585
+
586
+ extrnlcptn=""
587
+ if External_Captions:
588
+ extrnlcptn="--external_captions"
589
+
590
+ precision="fp16"
591
+
592
+
593
+
594
+ def train_only_text(SESSION_DIR, MODEL_NAME, INSTANCE_DIR, OUTPUT_DIR, Seed, Resolution, ofstnse, extrnlcptn, precision, Training_Epochs):
595
+ print('Training the Text Encoder...')
596
+ call('accelerate launch /workspace/diffusers/examples/dreambooth/train_dreambooth_rnpd_sdxl_TI.py \
597
+ '+ofstnse+' \
598
+ '+extrnlcptn+' \
599
+ --dim='+str(dim)+' \
600
+ --ofstnselvl='+str(ofstnselvl)+' \
601
+ --image_captions_filename \
602
+ --Session_dir='+SESSION_DIR+' \
603
+ --pretrained_model_name_or_path='+MODEL_NAME+' \
604
+ --instance_data_dir='+INSTANCE_DIR+' \
605
+ --output_dir='+OUTPUT_DIR+' \
606
+ --captions_dir='+CAPTIONS_DIR+' \
607
+ --seed='+str(Seed)+' \
608
+ --resolution='+str(Resolution)+' \
609
+ --mixed_precision='+str(precision)+' \
610
+ --train_batch_size=1 \
611
+ --gradient_accumulation_steps=1 '+GC+ ' \
612
+ --use_8bit_adam \
613
+ --learning_rate='+str(Text_Encoder_Learning_Rate)+' \
614
+ --lr_scheduler="cosine" \
615
+ --lr_warmup_steps=0 \
616
+ --num_train_epochs='+str(Training_Epochs), shell=True)
617
+
618
+
619
+
620
+ def train_only_unet(SESSION_DIR, MODEL_NAME, INSTANCE_DIR, OUTPUT_DIR, Seed, Resolution, ofstnse, extrnlcptn, precision, Training_Epochs):
621
+ print('Training the UNet...')
622
+ call('accelerate launch /workspace/diffusers/examples/dreambooth/train_dreambooth_rnpd_sdxl_lora.py \
623
+ '+ofstnse+' \
624
+ '+extrnlcptn+' \
625
+ --dim='+str(dim)+' \
626
+ --ofstnselvl='+str(ofstnselvl)+' \
627
+ --image_captions_filename \
628
+ --Session_dir='+SESSION_DIR+' \
629
+ --pretrained_model_name_or_path='+MODEL_NAME+' \
630
+ --instance_data_dir='+INSTANCE_DIR+' \
631
+ --output_dir='+OUTPUT_DIR+' \
632
+ --captions_dir='+CAPTIONS_DIR+' \
633
+ --seed='+str(Seed)+' \
634
+ --resolution='+str(Resolution)+' \
635
+ --mixed_precision='+str(precision)+' \
636
+ --train_batch_size=1 \
637
+ --gradient_accumulation_steps=1 '+GC+ ' \
638
+ --use_8bit_adam \
639
+ --learning_rate='+str(Unet_Learning_Rate)+' \
640
+ --lr_scheduler="cosine" \
641
+ --lr_warmup_steps=0 \
642
+ --num_train_epochs='+str(Training_Epochs), shell=True)
643
+
644
+
645
+
646
+ if Unet_Training_Epochs!=0:
647
+ if Text_Encoder_Training_Epochs!=0:
648
+ train_only_text(SESSION_DIR, MODEL_NAME, INSTANCE_DIR, OUTPUT_DIR, Seed, Resolution, ofstnse, extrnlcptn, precision, Training_Epochs=Text_Encoder_Training_Epochs)
649
+ clear_output()
650
+ train_only_unet(SESSION_DIR, MODEL_NAME, INSTANCE_DIR, OUTPUT_DIR, Seed, Resolution, ofstnse, extrnlcptn, precision, Training_Epochs=Unet_Training_Epochs)
651
+ else :
652
+ print('Nothing to do')
653
+
654
+
655
+ if os.path.exists(SESSION_DIR+'/'+Session_Name+'.safetensors'):
656
+ clear_output()
657
+ print("DONE, the LoRa model is in the session's folder")
658
+ else:
659
+ print("Something went wrong")
660
+
661
+
662
+
663
+
664
+ def sdcmff(Huggingface_token_optional, MDLPTH, restored):
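+ # Restore ComfyUI from a HuggingFace backup when available, otherwise clone/update it, then link the trained LoRA and the SDXL checkpoints into its model folders.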
665
+
666
+ from slugify import slugify
667
+ from huggingface_hub import HfApi, CommitOperationAdd, create_repo
668
+
669
+ os.chdir('/workspace')
670
+
671
+ if restored:
672
+ Huggingface_token_optional=""
673
+
674
+ if Huggingface_token_optional!="":
675
+ username = HfApi().whoami(Huggingface_token_optional)["name"]
676
+ backup=f"https://huggingface.co/datasets/{username}/fast-stable-diffusion/resolve/main/sdcomfy_backup_rnpd.tar.zst"
677
+ headers = {"Authorization": f"Bearer {Huggingface_token_optional}"}
678
+ response = requests.head(backup, headers=headers)
679
+ if response.status_code == 302:
680
+ restored=True
681
+ print('Restoring ComfyUI...')
682
+ open('/workspace/sdcomfy_backup_rnpd.tar.zst', 'wb').write(requests.get(backup, headers=headers).content)
683
+ call('tar --zstd -xf sdcomfy_backup_rnpd.tar.zst', shell=True)
684
+ call('rm sdcomfy_backup_rnpd.tar.zst', shell=True)
685
+ else:
686
+ print('Backup not found, using a fresh/existing repo...')
687
+ time.sleep(2)
688
+ if not os.path.exists('ComfyUI'):
689
+ call('git clone -q --depth 1 https://github.com/comfyanonymous/ComfyUI', shell=True)
690
+ else:
691
+ print('Installing/Updating the repo...')
692
+ if not os.path.exists('ComfyUI'):
693
+ call('git clone -q --depth 1 https://github.com/comfyanonymous/ComfyUI', shell=True)
694
+
695
+ os.chdir('ComfyUI')
696
+ call('git reset --hard', shell=True)
697
+ print('')
698
+ call('git pull', shell=True)
699
+
700
+ if os.path.exists(MDLPTH):
701
+ call('ln -s '+MDLPTH+' models/loras', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
702
+
703
+ clean_symlinks('models/loras')
704
+
705
+ if not os.path.exists('models/checkpoints/sd_xl_base_1.0.safetensors'):
706
+ call('ln -s /workspace/auto-models/* models/checkpoints', shell=True)
707
+
708
+
709
+ podid=os.environ.get('RUNPOD_POD_ID')
710
+ localurl=f"https://{podid}-3001.proxy.runpod.net"
711
+ call("sed -i 's@print(\"To see the GUI go to: http://{}:{}\".format(address, port))@print(\"\u2714 Connected\")\\n print(\""+localurl+"\")@' /workspace/ComfyUI/server.py", shell=True)
712
+ os.chdir('/workspace')
713
+
714
+ return restored
715
+
716
+
717
+
718
+
719
+ def test(MDLPTH, User, Password, Huggingface_token_optional, restoreda):
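+ # Restore or install the AUTOMATIC1111 webui, link the trained LoRA, patch gradio for the RunPod proxy and return the SDXL launch flags.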
720
+
721
+ from slugify import slugify
722
+ from huggingface_hub import HfApi, CommitOperationAdd, create_repo
723
+ import gradio
724
+
725
+ gradio.close_all()
726
+
727
+
728
+ auth=f"--gradio-auth {User}:{Password}"
729
+ if User =="" or Password=="":
730
+ auth=""
731
+
732
+
733
+ if restoreda:
734
+ Huggingface_token_optional=""
735
+
736
+ if Huggingface_token_optional!="":
737
+ username = HfApi().whoami(Huggingface_token_optional)["name"]
738
+ backup=f"https://huggingface.co/datasets/{username}/fast-stable-diffusion/resolve/main/sd_backup_rnpd.tar.zst"
739
+ headers = {"Authorization": f"Bearer {Huggingface_token_optional}"}
740
+ response = requests.head(backup, headers=headers)
741
+ if response.status_code == 302:
742
+ restoreda=True
743
+ print('Restoring the SD folder...')
744
+ open('/workspace/sd_backup_rnpd.tar.zst', 'wb').write(requests.get(backup, headers=headers).content)
745
+ call('tar --zstd -xf sd_backup_rnpd.tar.zst', shell=True)
746
+ call('rm sd_backup_rnpd.tar.zst', shell=True)
747
+ else:
748
+ print('Backup not found, using a fresh/existing repo...')
749
+ time.sleep(2)
750
+ if not os.path.exists('/workspace/sd/stablediffusiond'): #reset later
751
+ call('wget -q -O sd_mrep.tar.zst https://huggingface.co/TheLastBen/dependencies/resolve/main/sd_mrep.tar.zst', shell=True)
752
+ call('tar --zstd -xf sd_mrep.tar.zst', shell=True)
753
+ call('rm sd_mrep.tar.zst', shell=True)
754
+ os.chdir('/workspace/sd')
755
+ if not os.path.exists('stable-diffusion-webui'):
756
+ call('git clone -q --depth 1 --branch master https://github.com/AUTOMATIC1111/stable-diffusion-webui', shell=True)
757
+
758
+ else:
759
+ print('Installing/Updating the repo...')
760
+ os.chdir('/workspace')
761
+ if not os.path.exists('/workspace/sd/stablediffusiond'): #reset later
762
+ call('wget -q -O sd_mrep.tar.zst https://huggingface.co/TheLastBen/dependencies/resolve/main/sd_mrep.tar.zst', shell=True)
763
+ call('tar --zstd -xf sd_mrep.tar.zst', shell=True)
764
+ call('rm sd_mrep.tar.zst', shell=True)
765
+
766
+ os.chdir('/workspace/sd')
767
+ if not os.path.exists('stable-diffusion-webui'):
768
+ call('git clone -q --depth 1 --branch master https://github.com/AUTOMATIC1111/stable-diffusion-webui', shell=True)
769
+
770
+
771
+ os.chdir('/workspace/sd/stable-diffusion-webui/')
772
+ call('git reset --hard', shell=True)
773
+ print('')
774
+ call('git pull', shell=True)
775
+
776
+
777
+ if os.path.exists(MDLPTH):
778
+ call('mkdir models/Lora', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
779
+ call('ln -s '+MDLPTH+' models/Lora', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w'))
780
+
781
+ if not os.path.exists('models/Stable-diffusion/sd_xl_base_1.0.safetensors'):
782
+ call('ln -s /workspace/auto-models/* models/Stable-diffusion', shell=True)
783
+
784
+ clean_symlinks('models/Lora')
785
+
786
+ os.chdir('/workspace')
787
+
788
+
789
+ call('wget -q -O /usr/local/lib/python3.10/dist-packages/gradio/blocks.py https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/main/AUTOMATIC1111_files/blocks.py', shell=True)
790
+
791
+ os.chdir('/workspace/sd/stable-diffusion-webui/modules')
792
+
793
+ call("sed -i 's@possible_sd_paths =.*@possible_sd_paths = [\"/workspace/sd/stablediffusion\"]@' /workspace/sd/stable-diffusion-webui/modules/paths.py", shell=True)
794
+ call("sed -i 's@\.\.\/@src/@g' /workspace/sd/stable-diffusion-webui/modules/paths.py", shell=True)
795
+ call("sed -i 's@src\/generative-models@generative-models@g' /workspace/sd/stable-diffusion-webui/modules/paths.py", shell=True)
796
+
797
+ call("sed -i 's@\[\"sd_model_checkpoint\"\]@\[\"sd_model_checkpoint\", \"sd_vae\", \"CLIP_stop_at_last_layers\", \"inpainting_mask_weight\", \"initial_noise_multiplier\"\]@g' /workspace/sd/stable-diffusion-webui/modules/shared.py", shell=True)
798
+ call("sed -i 's@print(\"No module.*@@' /workspace/sd/stablediffusion/ldm/modules/diffusionmodules/model.py", shell=True)
799
+ os.chdir('/workspace/sd/stable-diffusion-webui')
800
+ clear_output()
801
+
802
+ podid=os.environ.get('RUNPOD_POD_ID')
803
+ localurl=f"{podid}-3001.proxy.runpod.net"
804
+
805
+ for line in fileinput.input('/usr/local/lib/python3.10/dist-packages/gradio/blocks.py', inplace=True):
806
+ if line.strip().startswith('self.server_name ='):
807
+ line = f' self.server_name = "{localurl}"\n'
808
+ if line.strip().startswith('self.protocol = "https"'):
809
+ line = ' self.protocol = "https"\n'
810
+ if line.strip().startswith('if self.local_url.startswith("https") or self.is_colab'):
811
+ line = ''
812
+ if line.strip().startswith('else "http"'):
813
+ line = ''
814
+ sys.stdout.write(line)
815
+
816
+
817
+ configf="--disable-console-progressbars --upcast-sampling --no-half-vae --disable-safe-unpickle --api --opt-sdp-attention --enable-insecure-extension-access --no-download-sd-model --skip-version-check --listen --port 3000 --ckpt /workspace/sd/stable-diffusion-webui/models/Stable-diffusion/sd_xl_base_1.0.safetensors "+auth
818
+
819
+
820
+ return configf, restoreda
821
+
822
+
823
+
824
+
825
+ def clean():
826
+
827
+ Sessions=os.listdir("/workspace/Fast-Dreambooth/Sessions")
828
+
829
+ s = widgets.Select(
830
+ options=Sessions,
831
+ rows=5,
832
+ description='',
833
+ disabled=False
834
+ )
835
+
836
+ out=widgets.Output()
837
+
838
+ d = widgets.Button(
839
+ description='Remove',
840
+ disabled=False,
841
+ button_style='warning',
842
+ tooltip='Remove the selected session',
843
+ icon='warning'
844
+ )
845
+
846
+ def rem(d):
847
+ with out:
848
+ if s.value is not None:
849
+ clear_output()
850
+ print("THE SESSION "+s.value+" HAS BEEN REMOVED FROM THE STORAGE")
851
+ call('rm -r /workspace/Fast-Dreambooth/Sessions/'+s.value, shell=True)
852
+ if os.path.exists('/workspace/models/'+s.value):
853
+ call('rm -r /workspace/models/'+s.value, shell=True)
854
+ s.options=os.listdir("/workspace/Fast-Dreambooth/Sessions")
855
+
856
+
857
+ else:
858
+ d.close()
859
+ s.close()
860
+ clear_output()
861
+ print("NOTHING TO REMOVE")
862
+
863
+ d.on_click(rem)
864
+ if s.value is not None:
865
+ display(s,d,out)
866
+ else:
867
+ print("NOTHING TO REMOVE")
868
+
869
+
870
+
871
+ def crop_image(im, size):
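+ # Focal-point crop helper: weigh face, corner and entropy detections to pick a crop centre on the resized image for a size x size crop.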
872
+
873
+ import cv2
874
+
875
+ GREEN = "#0F0"
876
+ BLUE = "#00F"
877
+ RED = "#F00"
878
+
879
+ def focal_point(im, settings):
880
+ corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
881
+ entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
882
+ face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
883
+
884
+ pois = []
885
+
886
+ weight_pref_total = 0
887
+ if len(corner_points) > 0:
888
+ weight_pref_total += settings.corner_points_weight
889
+ if len(entropy_points) > 0:
890
+ weight_pref_total += settings.entropy_points_weight
891
+ if len(face_points) > 0:
892
+ weight_pref_total += settings.face_points_weight
893
+
894
+ corner_centroid = None
895
+ if len(corner_points) > 0:
896
+ corner_centroid = centroid(corner_points)
897
+ corner_centroid.weight = settings.corner_points_weight / weight_pref_total
898
+ pois.append(corner_centroid)
899
+
900
+ entropy_centroid = None
901
+ if len(entropy_points) > 0:
902
+ entropy_centroid = centroid(entropy_points)
903
+ entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
904
+ pois.append(entropy_centroid)
905
+
906
+ face_centroid = None
907
+ if len(face_points) > 0:
908
+ face_centroid = centroid(face_points)
909
+ face_centroid.weight = settings.face_points_weight / weight_pref_total
910
+ pois.append(face_centroid)
911
+
912
+ average_point = poi_average(pois, settings)
913
+
914
+ return average_point
915
+
916
+
917
+ def image_face_points(im, settings):
918
+
919
+ np_im = np.array(im)
920
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
921
+
922
+ tries = [
923
+ [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
924
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
925
+ [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
926
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
927
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
928
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
929
+ [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
930
+ [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
931
+ ]
932
+ for t in tries:
933
+ classifier = cv2.CascadeClassifier(t[0])
934
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
935
+ try:
936
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
937
+ minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
938
+ except:
939
+ continue
940
+
941
+ if len(faces) > 0:
942
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
943
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
944
+ return []
945
+
946
+
947
+ def image_corner_points(im, settings):
948
+ grayscale = im.convert("L")
949
+
950
+ # naive attempt at preventing focal points from collecting at watermarks near the bottom
951
+ gd = ImageDraw.Draw(grayscale)
952
+ gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
953
+
954
+ np_im = np.array(grayscale)
955
+
956
+ points = cv2.goodFeaturesToTrack(
957
+ np_im,
958
+ maxCorners=100,
959
+ qualityLevel=0.04,
960
+ minDistance=min(grayscale.width, grayscale.height)*0.06,
961
+ useHarrisDetector=False,
962
+ )
963
+
964
+ if points is None:
965
+ return []
966
+
967
+ focal_points = []
968
+ for point in points:
969
+ x, y = point.ravel()
970
+ focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
971
+
972
+ return focal_points
973
+
974
+
975
+ def image_entropy_points(im, settings):
+     landscape = im.height < im.width
+     portrait = im.height > im.width
+     if landscape:
+         move_idx = [0, 2]
+         move_max = im.size[0]
+     elif portrait:
+         move_idx = [1, 3]
+         move_max = im.size[1]
+     else:
+         return []
+
+     e_max = 0
+     crop_current = [0, 0, settings.crop_width, settings.crop_height]
+     crop_best = crop_current
+     while crop_current[move_idx[1]] < move_max:
+         crop = im.crop(tuple(crop_current))
+         e = image_entropy(crop)
+
+         if (e > e_max):
+             e_max = e
+             crop_best = list(crop_current)
+
+         crop_current[move_idx[0]] += 4
+         crop_current[move_idx[1]] += 4
+
+     x_mid = int(crop_best[0] + settings.crop_width/2)
+     y_mid = int(crop_best[1] + settings.crop_height/2)
+
+     return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
+
+
+ def image_entropy(im):
+     # greyscale image entropy
+     # band = np.asarray(im.convert("L"))
+     band = np.asarray(im.convert("1"), dtype=np.uint8)
+     hist, _ = np.histogram(band, bins=range(0, 256))
+     hist = hist[hist > 0]
+     return -np.log2(hist / hist.sum()).sum()
+
+ def centroid(pois):
+     x = [poi.x for poi in pois]
+     y = [poi.y for poi in pois]
+     return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+
+
+ def poi_average(pois, settings):
+     weight = 0.0
+     x = 0.0
+     y = 0.0
+     for poi in pois:
+         weight += poi.weight
+         x += poi.x * poi.weight
+         y += poi.y * poi.weight
+     avg_x = round(weight and x / weight)
+     avg_y = round(weight and y / weight)
+
+     return PointOfInterest(avg_x, avg_y)
+
+
+ def is_landscape(w, h):
+     return w > h
+
+
+ def is_portrait(w, h):
+     return h > w
+
+
+ def is_square(w, h):
+     return w == h
+
+
+ class PointOfInterest:
+     def __init__(self, x, y, weight=1.0, size=10):
+         self.x = x
+         self.y = y
+         self.weight = weight
+         self.size = size
+
+     def bounding(self, size):
+         return [
+             self.x - size//2,
+             self.y - size//2,
+             self.x + size//2,
+             self.y + size//2
+         ]
+
+ class Settings:
+     def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5):
+         self.crop_width = crop_width
+         self.crop_height = crop_height
+         self.corner_points_weight = corner_points_weight
+         self.entropy_points_weight = entropy_points_weight
+         self.face_points_weight = face_points_weight
+
+     settings = Settings(
+         crop_width = size,
+         crop_height = size,
+         face_points_weight = 0.9,
+         entropy_points_weight = 0.15,
+         corner_points_weight = 0.5,
+     )
+
+     scale_by = 1
+     if is_landscape(im.width, im.height):
+         scale_by = settings.crop_height / im.height
+     elif is_portrait(im.width, im.height):
+         scale_by = settings.crop_width / im.width
+     elif is_square(im.width, im.height):
+         if is_square(settings.crop_width, settings.crop_height):
+             scale_by = settings.crop_width / im.width
+         elif is_landscape(settings.crop_width, settings.crop_height):
+             scale_by = settings.crop_width / im.width
+         elif is_portrait(settings.crop_width, settings.crop_height):
+             scale_by = settings.crop_height / im.height
+
+     im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+     im_debug = im.copy()
+
+     focus = focal_point(im_debug, settings)
+
+     # take the focal point and turn it into crop coordinates that try to center over the focal
+     # point but then get adjusted back into the frame
+     y_half = int(settings.crop_height / 2)
+     x_half = int(settings.crop_width / 2)
+
+     x1 = focus.x - x_half
+     if x1 < 0:
+         x1 = 0
+     elif x1 + settings.crop_width > im.width:
+         x1 = im.width - settings.crop_width
+
+     y1 = focus.y - y_half
+     if y1 < 0:
+         y1 = 0
+     elif y1 + settings.crop_height > im.height:
+         y1 = im.height - settings.crop_height
+
+     x2 = x1 + settings.crop_width
+     y2 = y1 + settings.crop_height
+
+     crop = [x1, y1, x2, y2]
+
+     results = []
+
+     results.append(im.crop(tuple(crop)))
+
+     return results
+
+
+
+ def clean_symlinks(path):
+     for item in os.listdir(path):
+         lnk = os.path.join(path, item)
+         if os.path.islink(lnk) and not os.path.exists(os.readlink(lnk)):
+             os.remove(lnk)
+
train_dreambooth_rnpd_sdxl_lora.py ADDED
@@ -0,0 +1,782 @@
+ import argparse
2
+ import itertools
3
+ import math
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Optional
7
+ import subprocess
8
+ import sys
9
+
10
+ import gc
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint
14
+ from torch.utils.data import Dataset
15
+ from transformers import AutoTokenizer, PretrainedConfig
16
+ import bitsandbytes as bnb
17
+
18
+ from accelerate import Accelerator
19
+ from accelerate.logging import get_logger
20
+ from accelerate.utils import set_seed
21
+ from contextlib import nullcontext
22
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
23
+ from diffusers.optimization import get_scheduler
24
+ from huggingface_hub import HfFolder, Repository, whoami
25
+ from PIL import Image
26
+ from torchvision import transforms
27
+ from tqdm import tqdm
28
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, CLIPTextModelWithProjection
29
+
30
+ from lora_sdxl import *
31
+
32
+ logger = get_logger(__name__)
33
+
34
+
35
+ def import_model_class_from_model_name_or_path(
36
+ pretrained_model_name_or_path: str, subfolder: str = "text_encoder"
37
+ ):
38
+ text_encoder_config = PretrainedConfig.from_pretrained(
39
+ pretrained_model_name_or_path,
40
+ subfolder=subfolder,
41
+ use_auth_token=True
42
+ )
43
+ model_class = text_encoder_config.architectures[0]
44
+
45
+ if model_class == "CLIPTextModel":
46
+ from transformers import CLIPTextModel
47
+
48
+ return CLIPTextModel
49
+ elif model_class == "CLIPTextModelWithProjection":
50
+ from transformers import CLIPTextModelWithProjection
51
+
52
+ return CLIPTextModelWithProjection
53
+ else:
54
+ raise ValueError(f"{model_class} is not supported.")
55
+
56
+
57
+ def parse_args():
58
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
59
+ parser.add_argument(
60
+ "--pretrained_model_name_or_path",
61
+ type=str,
62
+ default=None,
63
+ required=True,
64
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
65
+ )
66
+ parser.add_argument(
67
+ "--tokenizer_name",
68
+ type=str,
69
+ default=None,
70
+ help="Pretrained tokenizer name or path if not the same as model_name",
71
+ )
72
+ parser.add_argument(
73
+ "--instance_data_dir",
74
+ type=str,
75
+ default=None,
76
+ required=True,
77
+ help="A folder containing the training data of instance images.",
78
+ )
79
+ parser.add_argument(
80
+ "--class_data_dir",
81
+ type=str,
82
+ default=None,
83
+ required=False,
84
+ help="A folder containing the training data of class images.",
85
+ )
86
+ parser.add_argument(
87
+ "--instance_prompt",
88
+ type=str,
89
+ default=None,
90
+ help="The prompt with identifier specifying the instance",
91
+ )
92
+ parser.add_argument(
93
+ "--class_prompt",
94
+ type=str,
95
+ default="",
96
+ help="The prompt to specify images in the same class as provided instance images.",
97
+ )
98
+ parser.add_argument(
99
+ "--with_prior_preservation",
100
+ default=False,
101
+ action="store_true",
102
+ help="Flag to add prior preservation loss.",
103
+ )
104
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
105
+ parser.add_argument(
106
+ "--num_class_images",
107
+ type=int,
108
+ default=100,
109
+ help=(
110
+ "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
111
+ " sampled with class_prompt."
112
+ ),
113
+ )
114
+ parser.add_argument(
115
+ "--output_dir",
116
+ type=str,
117
+ default="",
118
+ help="The output directory where the model predictions and checkpoints will be written.",
119
+ )
120
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
121
+ parser.add_argument(
122
+ "--resolution",
123
+ type=int,
124
+ default=512,
125
+ help=(
126
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
127
+ " resolution"
128
+ ),
129
+ )
130
+ parser.add_argument(
131
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
132
+ )
133
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
134
+ parser.add_argument(
135
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
136
+ )
137
+ parser.add_argument(
138
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
139
+ )
140
+ parser.add_argument("--num_train_epochs", type=int, default=1)
141
+ parser.add_argument(
142
+ "--max_train_steps",
143
+ type=int,
144
+ default=None,
145
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
146
+ )
147
+ parser.add_argument(
148
+ "--gradient_accumulation_steps",
149
+ type=int,
150
+ default=1,
151
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
152
+ )
153
+ parser.add_argument(
154
+ "--gradient_checkpointing",
155
+ action="store_true",
156
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
157
+ )
158
+ parser.add_argument(
159
+ "--learning_rate",
160
+ type=float,
161
+ default=5e-6,
162
+ help="Initial learning rate (after the potential warmup period) to use.",
163
+ )
164
+ parser.add_argument(
165
+ "--scale_lr",
166
+ action="store_true",
167
+ default=False,
168
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
169
+ )
170
+ parser.add_argument(
171
+ "--lr_scheduler",
172
+ type=str,
173
+ default="constant",
174
+ help=(
175
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
176
+ ' "constant", "constant_with_warmup"]'
177
+ ),
178
+ )
179
+ parser.add_argument(
180
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
181
+ )
182
+ parser.add_argument(
183
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
184
+ )
185
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
186
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
187
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
188
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
189
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
190
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
191
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
192
+ parser.add_argument(
193
+ "--hub_model_id",
194
+ type=str,
195
+ default=None,
196
+ help="The name of the repository to keep in sync with the local `output_dir`.",
197
+ )
198
+ parser.add_argument(
199
+ "--logging_dir",
200
+ type=str,
201
+ default="logs",
202
+ help=(
203
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
204
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
205
+ ),
206
+ )
207
+ parser.add_argument(
208
+ "--mixed_precision",
209
+ type=str,
210
+ default="no",
211
+ choices=["no", "fp16", "bf16"],
212
+ help=(
213
+ "Whether to use mixed precision. Choose"
214
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
215
+ "and an Nvidia Ampere GPU."
216
+ ),
217
+ )
218
+
219
+ parser.add_argument(
220
+ "--save_n_steps",
221
+ type=int,
222
+ default=1,
223
+ help=("Save the model every n global_steps"),
224
+ )
225
+
226
+
227
+ parser.add_argument(
228
+ "--save_starting_step",
229
+ type=int,
230
+ default=1,
231
+ help=("The step from which it starts saving intermediary checkpoints"),
232
+ )
233
+
234
+ parser.add_argument(
235
+ "--stop_text_encoder_training",
236
+ type=int,
237
+ default=1000000,
238
+ help=("The step at which the text_encoder is no longer trained"),
239
+ )
240
+
241
+
242
+ parser.add_argument(
243
+ "--image_captions_filename",
244
+ action="store_true",
245
+ help="Get captions from filename",
246
+ )
247
+
248
+
249
+
250
+ parser.add_argument(
251
+ "--Resumetr",
252
+ type=str,
253
+ default="False",
254
+ help="Resume training info",
255
+ )
256
+
257
+
258
+
259
+ parser.add_argument(
260
+ "--Session_dir",
261
+ type=str,
262
+ default="",
263
+ help="Current session directory",
264
+ )
265
+
266
+ parser.add_argument(
267
+ "--external_captions",
268
+ action="store_true",
269
+ default=False,
270
+ help="Use captions stored in a txt file",
271
+ )
272
+
273
+ parser.add_argument(
274
+ "--captions_dir",
275
+ type=str,
276
+ default="",
277
+ help="The folder where captions files are stored",
278
+ )
279
+
280
+ parser.add_argument(
281
+ "--offset_noise",
282
+ action="store_true",
283
+ default=False,
284
+ help="Offset Noise",
285
+ )
286
+
287
+ parser.add_argument(
288
+ "--ofstnselvl",
289
+ type=float,
290
+ default=0.03,
291
+ help="Offset Noise amount",
292
+ )
293
+
294
+ parser.add_argument(
295
+ "--resume",
296
+ action="store_true",
297
+ default=False,
298
+ help="resume training",
299
+ )
300
+
301
+ parser.add_argument(
302
+ "--dim",
303
+ type=int,
304
+ default=64,
305
+ help="LoRa dimension",
306
+ )
307
+
308
+ args = parser.parse_args()
309
+
310
+ return args
311
+
312
+
313
+
314
+ class DreamBoothDataset(Dataset):
315
+ """
316
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
317
+ It pre-processes the images and tokenizes the prompts.
318
+ """
319
+
320
+ def __init__(
321
+ self,
322
+ instance_data_root,
323
+ args,
324
+ tokenizers,
325
+ text_encoders,
326
+ size=512,
327
+ center_crop=False,
328
+ instance_prompt_hidden_states=None,
329
+ instance_unet_added_conditions=None,
330
+ ):
331
+ self.size = size
332
+ self.tokenizers=tokenizers
333
+ self.text_encoders=text_encoders
334
+ self.center_crop = center_crop
335
+ self.instance_prompt_hidden_states = instance_prompt_hidden_states
336
+ self.instance_unet_added_conditions = instance_unet_added_conditions
337
+ self.image_captions_filename = None
338
+
339
+ self.instance_data_root = Path(instance_data_root)
340
+ if not self.instance_data_root.exists():
341
+ raise ValueError("Instance images root doesn't exists.")
342
+
343
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
344
+ self.num_instance_images = len(self.instance_images_path)
345
+ self._length = self.num_instance_images
346
+
347
+ if args.image_captions_filename:
348
+ self.image_captions_filename = True
349
+
350
+ self.image_transforms = transforms.Compose(
351
+ [
352
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
353
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
354
+ transforms.ToTensor(),
355
+ transforms.Normalize([0.5], [0.5]),
356
+ ]
357
+ )
358
+
359
+ def __len__(self):
360
+ return self._length
361
+
362
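+ # Note: the args=parse_args() default below is evaluated once, when the class body is
+ # executed, so the command-line arguments are re-parsed and frozen at that point.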
+ def __getitem__(self, index, args=parse_args()):
363
+ example = {}
364
+ path = self.instance_images_path[index % self.num_instance_images]
365
+ instance_image = Image.open(path)
366
+ if instance_image.mode != "RGB":
367
+ instance_image = instance_image.convert("RGB")
368
+
369
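+ # Caption handling: by default the prompt is derived from the image filename (digits,
+ # underscores, parentheses and dashes stripped); with --external_captions, a matching
+ # .txt file in --captions_dir takes precedence when it exists.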
+ if self.image_captions_filename:
370
+ filename = Path(path).stem
371
+
372
+ pt=''.join([i for i in filename if not i.isdigit()])
373
+ pt=pt.replace("_"," ")
374
+ pt=pt.replace("(","")
375
+ pt=pt.replace(")","")
376
+ pt=pt.replace("-","")
377
+ pt=pt.replace("conceptimagedb","")
378
+
379
+ if args.external_captions:
380
+ cptpth=os.path.join(args.captions_dir, filename+'.txt')
381
+ if os.path.exists(cptpth):
382
+ with open(cptpth, "r") as f:
383
+ instance_prompt=f.read()
384
+ else:
385
+ instance_prompt=pt
386
+ else:
387
+ instance_prompt = pt
388
+
389
+ example["instance_images"] = self.image_transforms(instance_image)
390
+ with torch.no_grad():
391
+ example["instance_prompt_ids"], example["instance_added_cond_kwargs"]= compute_embeddings(args, instance_prompt, self.text_encoders, self.tokenizers)
392
+
393
+ return example
394
+
395
+
396
+ class PromptDataset(Dataset):
397
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
398
+
399
+ def __init__(self, prompt, num_samples):
400
+ self.prompt = prompt
401
+ self.num_samples = num_samples
402
+
403
+ def __len__(self):
404
+ return self.num_samples
405
+
406
+ def __getitem__(self, index):
407
+ example = {}
408
+ example["prompt"] = self.prompt
409
+ example["index"] = index
410
+ return example
411
+
412
+
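+ # SDXL prompt encoding: run both text encoders on the same prompt, keep the penultimate
+ # hidden states of each, concatenate them along the feature axis, and keep the pooled
+ # output of the last (second) encoder for the added conditions.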
413
+ def encode_prompt(text_encoders, tokenizers, prompt):
414
+ prompt_embeds_list = []
415
+
416
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
417
+ text_inputs = tokenizer(
418
+ prompt,
419
+ padding="max_length",
420
+ max_length=tokenizer.model_max_length,
421
+ truncation=True,
422
+ return_tensors="pt",
423
+ )
424
+ text_input_ids = text_inputs.input_ids
425
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
426
+
427
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
428
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
429
+ logger.warning(
430
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
431
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
432
+ )
433
+
434
+ with torch.no_grad():
435
+ prompt_embeds = text_encoder(
436
+ text_input_ids.to(text_encoder.device),
437
+ output_hidden_states=True,
438
+ )
439
+
440
+ # We are only interested in the pooled output of the final text encoder
441
+ pooled_prompt_embeds = prompt_embeds[0]
442
+ prompt_embeds = prompt_embeds.hidden_states[-2]
443
+ bs_embed, seq_len, _ = prompt_embeds.shape
444
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
445
+ prompt_embeds_list.append(prompt_embeds)
446
+
447
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
448
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
449
+ return prompt_embeds, pooled_prompt_embeds
450
+
451
+
452
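+ # The collate function stacks the pre-transformed images (cast to fp16) and concatenates
+ # the pre-computed conditioning tensors; note that "input_ids" already holds the prompt
+ # embeddings from compute_embeddings rather than token ids.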
+ def collate_fn(examples):
453
+
454
+ input_ids = [example["instance_prompt_ids"] for example in examples]
455
+ pixel_values = [example["instance_images"] for example in examples]
456
+ add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples]
457
+ add_time_ids = [example["instance_added_cond_kwargs"]["time_ids"] for example in examples]
458
+
459
+ pixel_values = torch.stack(pixel_values)
460
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).half()
461
+
462
+ input_ids = torch.cat(input_ids, dim=0)
463
+ add_text_embeds = torch.cat(add_text_embeds, dim=0)
464
+ add_time_ids = torch.cat(add_time_ids, dim=0)
465
+
466
+ batch = {
467
+ "input_ids": input_ids,
468
+ "pixel_values": pixel_values,
469
+ "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
470
+ }
471
+
472
+ return batch
473
+
474
+
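+ # SDXL micro-conditioning: besides the prompt embeddings, the UNet receives "time ids"
+ # built from (original_size, crops_coords_top_left, target_size), mirroring
+ # StableDiffusionXLPipeline._get_add_time_ids; here both sizes come from --resolution
+ # and the crop offset is fixed at (0, 0).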
475
+ def compute_embeddings(args, prompt, text_encoders, tokenizers):
476
+ original_size = (args.resolution, args.resolution)
477
+ target_size = (args.resolution, args.resolution)
478
+ crops_coords_top_left = (0, 0)
479
+
480
+ with torch.no_grad():
481
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
482
+ add_text_embeds = pooled_prompt_embeds
483
+
484
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
485
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
486
+ add_time_ids = torch.tensor([add_time_ids])
487
+
488
+ prompt_embeds = prompt_embeds.to('cuda')
489
+ add_text_embeds = add_text_embeds.to('cuda')
490
+ add_time_ids = add_time_ids.to('cuda', dtype=prompt_embeds.dtype)
491
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
492
+
493
+ return prompt_embeds, unet_added_cond_kwargs
494
+
495
+
496
+ class LatentsDataset(Dataset):
497
+ def __init__(self, latents_cache, text_encoder_cache, cond_cache):
498
+ self.latents_cache = latents_cache
499
+ self.text_encoder_cache = text_encoder_cache
500
+ self.cond_cache = cond_cache
501
+
502
+ def __len__(self):
503
+ return len(self.latents_cache)
504
+
505
+ def __getitem__(self, index):
506
+ return self.latents_cache[index], self.text_encoder_cache[index], self.cond_cache[index]
507
+
508
+
509
+
510
+ def main():
511
+ args = parse_args()
512
+ logging_dir = Path(args.output_dir, args.logging_dir)
513
+
514
+ accelerator = Accelerator(
515
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
516
+ mixed_precision=args.mixed_precision,
517
+ log_with="tensorboard",
518
+ logging_dir=logging_dir,
519
+ )
520
+
521
+
522
+ if args.seed is not None:
523
+ set_seed(args.seed)
524
+
525
+ # Handle the repository creation
526
+ if accelerator.is_main_process:
527
+ if args.output_dir is not None:
528
+ os.makedirs(args.output_dir, exist_ok=True)
529
+
530
+ # Load the tokenizers
531
+ tokenizer_one = AutoTokenizer.from_pretrained(
532
+ args.pretrained_model_name_or_path,
533
+ subfolder="tokenizer",
534
+ use_fast=False,
535
+ use_auth_token=True,
536
+ )
537
+ tokenizer_two = AutoTokenizer.from_pretrained(
538
+ args.pretrained_model_name_or_path,
539
+ subfolder="tokenizer_2",
540
+ use_fast=False,
541
+ use_auth_token=True
542
+ )
543
+
544
+
545
+
546
+ # import correct text encoder classes
547
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
548
+ args.pretrained_model_name_or_path, subfolder="text_encoder"
549
+ )
550
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
551
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2"
552
+ )
553
+
554
+ # Load scheduler and models
555
+
556
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
557
+ args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=True,
558
+ )
559
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
560
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", use_auth_token=True
561
+ )
562
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=True)
563
+ unet = UNet2DConditionModel.from_pretrained(
564
+ args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=True
565
+ )
566
+
567
+ vae.requires_grad_(False)
568
+ text_encoder_one.requires_grad_(False)
569
+ text_encoder_two.requires_grad_(False)
570
+ unet.requires_grad_(False)
571
+ text_encoder_one.eval()
572
+ text_encoder_two.eval()
573
+ vae.eval()
574
+
575
+ model_path = os.path.join(args.Session_dir, os.path.basename(args.Session_dir) + ".safetensors")
576
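+ # The LoRA network itself comes from lora_sdxl (star-imported above): create_network
+ # builds it with rank --dim and apply_to later injects it into the UNet; when --resume
+ # is set, previously saved weights are loaded from the session's .safetensors file.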
+ network = create_network(1, args.dim, 20000, unet)
577
+ if args.resume:
578
+ network.load_weights(model_path)
579
+
580
+ def set_diffusers_xformers_flag(model, valid):
581
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
582
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
583
+ module.set_use_memory_efficient_attention_xformers(valid)
584
+
585
+ for child in module.children():
586
+ fn_recursive_set_mem_eff(child)
587
+
588
+ fn_recursive_set_mem_eff(model)
589
+
590
+ set_diffusers_xformers_flag(unet, True)
591
+
592
+ network.apply_to(unet, True)
593
+ trainable_params = network.parameters()
594
+
595
+ tokenizers = [tokenizer_one, tokenizer_two]
596
+ text_encoders = [text_encoder_one, text_encoder_two]
597
+
598
+
599
+ if args.gradient_checkpointing:
600
+ unet.enable_gradient_checkpointing()
601
+
602
+ if args.scale_lr:
603
+ args.learning_rate = (
604
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
605
+ )
606
+
607
+ optimizer_class = bnb.optim.AdamW8bit
608
+
609
+ optimizer = optimizer_class(
610
+ trainable_params,
611
+ lr=args.learning_rate,
612
+ betas=(args.adam_beta1, args.adam_beta2),
613
+ weight_decay=args.adam_weight_decay,
614
+ eps=args.adam_epsilon,
615
+ )
616
+
617
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", use_auth_token=True)
618
+
619
+ train_dataset = DreamBoothDataset(
620
+ instance_data_root=args.instance_data_dir,
621
+ tokenizers=tokenizers,
622
+ text_encoders=text_encoders,
623
+ size=args.resolution,
624
+ center_crop=args.center_crop,
625
+ args=args
626
+ )
627
+
628
+ train_dataloader = torch.utils.data.DataLoader(
629
+ train_dataset,
630
+ batch_size=args.train_batch_size,
631
+ shuffle=True,
632
+ collate_fn=lambda examples: collate_fn(examples),
633
+ )
634
+
635
+
636
+ # Scheduler and math around the number of training steps.
637
+ overrode_max_train_steps = False
638
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
639
+ if args.max_train_steps is None:
640
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
641
+ overrode_max_train_steps = True
642
+
643
+ lr_scheduler = get_scheduler(
644
+ args.lr_scheduler,
645
+ optimizer=optimizer,
646
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
647
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
648
+ )
649
+
650
+
651
+ network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
652
+ network, optimizer, train_dataloader, lr_scheduler)
653
+
654
+ weight_dtype = torch.float32
655
+ if args.mixed_precision == "fp16":
656
+ weight_dtype = torch.float16
657
+ elif args.mixed_precision == "bf16":
658
+ weight_dtype = torch.bfloat16
659
+
660
+ unet.to(accelerator.device, dtype=weight_dtype)
661
+ vae.to(accelerator.device, dtype=weight_dtype)
662
+ network.prepare_grad_etc(network)
663
+
664
+
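+ # Pre-compute and cache the VAE latents, prompt embeddings and added conditions for
+ # every batch so that the VAE, tokenizers and text encoders can be freed from memory
+ # before the training loop starts.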
665
+ latents_cache = []
666
+ text_encoder_cache = []
667
+ cond_cache= []
668
+ for batch in train_dataloader:
669
+ with torch.no_grad():
670
+
671
+ batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
672
+ batch["unet_added_conditions"] = batch["unet_added_conditions"]
673
+
674
+ batch["pixel_values"]=(vae.encode(batch["pixel_values"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample() * vae.config.scaling_factor)
675
+
676
+ latents_cache.append(batch["pixel_values"])
677
+ text_encoder_cache.append(batch["input_ids"])
678
+ cond_cache.append(batch["unet_added_conditions"])
679
+
680
+ train_dataset = LatentsDataset(latents_cache, text_encoder_cache, cond_cache)
681
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)
682
+
683
+ del vae, tokenizers, text_encoders
684
+ gc.collect()
685
+ torch.cuda.empty_cache()
686
+
687
+
688
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
689
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
690
+ if overrode_max_train_steps:
691
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
692
+ # Afterwards we recalculate our number of training epochs
693
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
694
+
695
+ # We need to initialize the trackers we use, and also store our configuration.
696
+ # The trackers initializes automatically on the main process.
697
+ if accelerator.is_main_process:
698
+ accelerator.init_trackers("dreambooth", config=vars(args))
699
+
700
+ def bar(prg):
701
+ br='|'+'█' * prg + ' ' * (25-prg)+'|'
702
+ return br
703
+
704
+ # Train!
705
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
706
+ text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
707
+ logger.info("***** Running training *****")
708
+ logger.info(f" Num examples = {len(train_dataset)}")
709
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
710
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
711
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
712
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
713
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
714
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
715
+ # Only show the progress bar once on each machine.
716
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
717
+ global_step = 0
718
+
719
+ for epoch in range(args.num_train_epochs):
720
+ unet.train()
721
+ network.train()
722
+ for step, batch in enumerate(train_dataloader):
723
+ with accelerator.accumulate(unet):
724
+
725
+ with torch.no_grad():
726
+ model_input = batch[0][0]
727
+
728
+ # Sample noise that we'll add to the latents
729
+ if args.offset_noise:
730
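+ # NOTE: the offset-noise term is currently commented out below, so --offset_noise
+ # selects the same plain Gaussian noise as the else branch.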
+ noise = torch.randn_like(model_input)# + args.ofstnselvl * torch.randn(model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device)
731
+ else:
732
+ noise = torch.randn_like(model_input)
733
+
734
+ bsz = model_input.shape[0]
735
+
736
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device)
737
+ timesteps = timesteps.long()
738
+
739
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
740
+
741
+ # Predict the noise residual
742
+ with accelerator.autocast():
743
+ model_pred = unet(noisy_model_input, timesteps, batch[0][1], added_cond_kwargs=batch[0][2]).sample
744
+
745
+ # Get the target for the loss; this script assumes epsilon (noise) prediction
746
+ target = noise
747
+
748
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
749
+
750
+ accelerator.backward(loss)
751
+ optimizer.step()
752
+ lr_scheduler.step()
753
+ optimizer.zero_grad(set_to_none=True)
754
+
755
+ # Checks if the accelerator has performed an optimization step behind the scenes
756
+ if accelerator.sync_gradients:
757
+ progress_bar.update(1)
758
+ global_step += 1
759
+
760
+ fll=round((global_step*100)/args.max_train_steps)
761
+ fll=round(fll/4)
762
+ pr=bar(fll)
763
+
764
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
765
+ progress_bar.set_postfix(**logs)
766
+ progress_bar.set_description_str("Progress")
767
+ accelerator.log(logs, step=global_step)
768
+
769
+ if global_step >= args.max_train_steps:
770
+ break
771
+
772
+
773
+ accelerator.wait_for_everyone()
774
+ if accelerator.is_main_process:
775
+ network = accelerator.unwrap_model(network)
776
+ accelerator.end_training()
777
+ network.save_weights(model_path, torch.float16, None)
778
+
779
+ accelerator.end_training()
780
+
781
+ if __name__ == "__main__":
782
+ main()
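+
+ # Example invocation (illustrative only; paths and values are placeholders, and the
+ # surrounding notebook normally builds this command with the session's own paths):
+ #   accelerate launch train_dreambooth_rnpd_sdxl_lora.py \
+ #     --pretrained_model_name_or_path=/workspace/model \
+ #     --instance_data_dir=/workspace/instance_images \
+ #     --output_dir=/workspace/output \
+ #     --Session_dir=/workspace/Sessions/my_session \
+ #     --resolution=1024 --train_batch_size=1 --dim=64 \
+ #     --image_captions_filename --max_train_steps=1500 --mixed_precision=fp16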