ManglerFTW committed
Commit 3a18eba (Parent: ba2acac)

Upload 12 files
StableTuner_RunPod_Fix/captionBuddy.py ADDED
@@ -0,0 +1,967 @@
1
+ import tkinter as tk
2
+ from tkinter import ttk, Menu
3
+ import os
4
+ import subprocess
5
+ from PIL import Image, ImageTk, ImageDraw
6
+ import tkinter.filedialog as fd
7
+ import json
8
+ import sys
9
+ import tkinter.messagebox  # explicit import: the tk.messagebox dialogs used below are not imported by 'import tkinter' alone
11
+ from torchvision import transforms
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ import torch
15
+ import numpy as np
16
+ import requests
17
+ import random
18
+ import customtkinter as ctk
19
+ from customtkinter import ThemeManager
20
+
21
+ from clip_segmentation import ClipSeg
22
+
23
+ #main class
24
+ ctk.set_appearance_mode("dark")
25
+ ctk.set_default_color_theme("blue")
26
+
27
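+ # Modal dialog for batch mask generation: runs the parent's ClipSeg model over a folder of images
+ # and writes/updates the per-image "-masklabel.png" files according to the selected mode.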
+ class BatchMaskWindow(ctk.CTkToplevel):
28
+ def __init__(self, parent, path, *args, **kwargs):
29
+ ctk.CTkToplevel.__init__(self, parent, *args, **kwargs)
30
+ self.parent = parent
31
+
32
+ self.title("Batch process masks")
33
+ self.geometry("320x310")
34
+ self.resizable(False, False)
35
+ self.wait_visibility()
36
+ self.grab_set()
37
+ self.focus_set()
38
+
39
+ self.mode_var = tk.StringVar(self, "Create if absent")
40
+ self.modes = ["Replace all masks", "Create if absent", "Add to existing", "Subtract from existing"]
41
+
42
+ self.frame = ctk.CTkFrame(self, width=600, height=300)
43
+ self.frame.grid(row=0, column=0, sticky="nsew", padx=10, pady=10)
44
+
45
+ self.path_label = ctk.CTkLabel(self.frame, text="Folder", width=100)
46
+ self.path_label.grid(row=0, column=0, sticky="w",padx=5, pady=5)
47
+ self.path_entry = ctk.CTkEntry(self.frame, width=150)
48
+ self.path_entry.insert(0, path)
49
+ self.path_entry.grid(row=0, column=1, sticky="w", padx=5, pady=5)
50
+ self.path_button = ctk.CTkButton(self.frame, width=30, text="...", command=lambda: self.browse_for_path(self.path_entry))
51
+ self.path_button.grid(row=0, column=1, sticky="e", padx=5, pady=5)
52
+
53
+ self.prompt_label = ctk.CTkLabel(self.frame, text="Prompt", width=100)
54
+ self.prompt_label.grid(row=1, column=0, sticky="w",padx=5, pady=5)
55
+ self.prompt_entry = ctk.CTkEntry(self.frame, width=200)
56
+ self.prompt_entry.grid(row=1, column=1, sticky="w", padx=5, pady=5)
57
+
58
+ self.mode_label = ctk.CTkLabel(self.frame, text="Mode", width=100)
59
+ self.mode_label.grid(row=2, column=0, sticky="w", padx=5, pady=5)
60
+ self.mode_dropdown = ctk.CTkOptionMenu(self.frame, variable=self.mode_var, values=self.modes, dynamic_resizing=False, width=200)
61
+ self.mode_dropdown.grid(row=2, column=1, sticky="w", padx=5, pady=5)
62
+
63
+ self.threshold_label = ctk.CTkLabel(self.frame, text="Threshold", width=100)
64
+ self.threshold_label.grid(row=3, column=0, sticky="w", padx=5, pady=5)
65
+ self.threshold_entry = ctk.CTkEntry(self.frame, width=200, placeholder_text="0.0 - 1.0")
66
+ self.threshold_entry.insert(0, "0.3")
67
+ self.threshold_entry.grid(row=3, column=1, sticky="w", padx=5, pady=5)
68
+
69
+ self.smooth_label = ctk.CTkLabel(self.frame, text="Smooth", width=100)
70
+ self.smooth_label.grid(row=4, column=0, sticky="w", padx=5, pady=5)
71
+ self.smooth_entry = ctk.CTkEntry(self.frame, width=200, placeholder_text="5")
72
+ self.smooth_entry.insert(0, 5)
73
+ self.smooth_entry.grid(row=4, column=1, sticky="w", padx=5, pady=5)
74
+
75
+ self.expand_label = ctk.CTkLabel(self.frame, text="Expand", width=100)
76
+ self.expand_label.grid(row=5, column=0, sticky="w", padx=5, pady=5)
77
+ self.expand_entry = ctk.CTkEntry(self.frame, width=200, placeholder_text="10")
78
+ self.expand_entry.insert(0, 10)
79
+ self.expand_entry.grid(row=5, column=1, sticky="w", padx=5, pady=5)
80
+
81
+ self.progress_label = ctk.CTkLabel(self.frame, text="Progress: 0/0", width=100)
82
+ self.progress_label.grid(row=6, column=0, sticky="w", padx=5, pady=5)
83
+ self.progress = ctk.CTkProgressBar(self.frame, orientation="horizontal", mode="determinate", width=200)
84
+ self.progress.grid(row=6, column=1, sticky="w", padx=5, pady=5)
85
+
86
+ self.create_masks_button = ctk.CTkButton(self.frame, text="Create Masks", width=310, command=self.create_masks)
87
+ self.create_masks_button.grid(row=7, column=0, columnspan=2, sticky="w", padx=5, pady=5)
88
+
89
+ self.frame.pack(fill="both", expand=True)
90
+
91
+ def browse_for_path(self, entry_box):
92
+ # get the path from the user
93
+ path = fd.askdirectory()
94
+ # set the path to the entry box
95
+ # delete entry box text
96
+ entry_box.focus_set()
97
+ entry_box.delete(0, tk.END)
98
+ entry_box.insert(0, path)
99
+ self.focus_set()
100
+
101
+ def set_progress(self, value, max_value):
102
+ progress = value / max_value
103
+ self.progress.set(progress)
104
+ self.progress_label.configure(text="{0}/{1}".format(value, max_value))
105
+ self.progress.update()
106
+
107
+ def create_masks(self):
108
+ self.parent.load_clip_seg_model()
109
+
110
+ mode = {
111
+ "Replace all masks": "replace",
112
+ "Create if absent": "fill",
113
+ "Add to existing": "add",
114
+ "Subtract from existing": "subtract"
115
+ }[self.mode_var.get()]
116
+
117
+ self.parent.clip_seg.mask_folder(
118
+ sample_dir=self.path_entry.get(),
119
+ prompts=[self.prompt_entry.get()],
120
+ mode=mode,
121
+ threshold=float(self.threshold_entry.get()),
122
+ smooth_pixels=int(self.smooth_entry.get()),
123
+ expand_pixels=int(self.expand_entry.get()),
124
+ progress_callback=self.set_progress,
125
+ )
126
+ self.parent.load_image()
127
+
128
+
129
+ def _check_file_type(f: str) -> bool:
130
+ return f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp', ".bmp", ".tiff"))
131
+
132
+
133
+ class ImageBrowser(ctk.CTkToplevel):
134
+ def __init__(self,mainProcess=None):
135
+ super().__init__()
136
+ if not os.path.exists("scripts/BLIP"):
137
+ print("Getting BLIP from GitHub.")
138
+ subprocess.run(["git", "clone", "https://github.com/salesforce/BLIP", "scripts/BLIP"])
139
+ #if not os.path.exists("scripts/CLIP"):
140
+ # print("Getting CLIP from GitHub.")
141
+ # subprocess.run(["git", "clone", "https://github.com/pharmapsychotic/clip-interrogator.git', 'scripts/CLIP"])
142
+ blip_path = "scripts/BLIP"
143
+ sys.path.append(blip_path)
144
+ #clip_path = "scripts/CLIP"
145
+ #sys.path.append(clip_path)
146
+ self.mainProcess = mainProcess
147
+ self.captioner_folder = os.path.dirname(os.path.realpath(__file__))
148
+ self.clip_seg = None
149
+ self.PILimage = None
150
+ self.PILmask = None
151
+ self.mask_draw_x = 0
152
+ self.mask_draw_y = 0
153
+ self.mask_draw_radius = 20
154
+ #self = master
155
+ #self.overrideredirect(True)
156
+ #self.title_bar = TitleBar(self)
157
+ #self.title_bar.pack(side="top", fill="x")
158
+ #make not user resizable
159
+ self.title("Caption Buddy")
160
+ #self.resizable(False, False)
161
+ self.geometry("720x820")
162
+ self.top_frame = ctk.CTkFrame(self,fg_color='transparent')
163
+ self.top_frame.pack(side="top", fill="x",expand=False)
164
+ self.top_subframe = ctk.CTkFrame(self.top_frame,fg_color='transparent')
165
+ self.top_subframe.pack(side="bottom", fill="x",pady=10)
166
+ self.top_subframe.grid_columnconfigure(0, weight=1)
167
+ self.top_subframe.grid_columnconfigure(1, weight=1)
168
+ self.tip_frame = ctk.CTkFrame(self,fg_color='transparent')
169
+ self.tip_frame.pack(side="top")
170
+ self.dark_mode_var = "#202020"
171
+ #self.dark_purple_mode_var = "#1B0F1B"
172
+ self.dark_mode_title_var = "#286aff"
173
+ self.dark_mode_button_pressed_var = "#BB91B6"
174
+ self.dark_mode_button_var = "#8ea0e1"
175
+ self.dark_mode_text_var = "#c6c7c8"
176
+ #self.configure(bg_color=self.dark_mode_var)
177
+ self.canvas = ctk.CTkLabel(self,text='', width=600, height=600)
178
+ #self.canvas.configure(bg_color=self.dark_mode_var)
179
+ #create temporary image for canvas
180
+ self.canvas.pack()
181
+ self.cur_img_index = 0
182
+ self.image_count = 0
183
+ #make a frame with a grid under the canvas
184
+ self.frame = ctk.CTkFrame(self)
185
+ #grid
186
+ self.frame.grid_columnconfigure(0, weight=1)
187
+ self.frame.grid_columnconfigure(1, weight=100)
188
+ self.frame.grid_columnconfigure(2, weight=1)
189
+ self.frame.grid_rowconfigure(0, weight=1)
190
+
191
+ #show the frame
192
+ self.frame.pack(side="bottom", fill="x")
193
+ #bottom frame
194
+ self.bottom_frame = ctk.CTkFrame(self)
195
+ #make grid
196
+ self.bottom_frame.grid_columnconfigure(0, weight=0)
197
+ self.bottom_frame.grid_columnconfigure(1, weight=2)
198
+ self.bottom_frame.grid_columnconfigure(2, weight=0)
199
+ self.bottom_frame.grid_columnconfigure(3, weight=2)
200
+ self.bottom_frame.grid_columnconfigure(4, weight=0)
201
+ self.bottom_frame.grid_columnconfigure(5, weight=2)
202
+ self.bottom_frame.grid_rowconfigure(0, weight=1)
203
+ #show the frame
204
+ self.bottom_frame.pack(side="bottom", fill="x")
205
+
206
+ self.image_index = 0
207
+ self.image_list = []
208
+ self.caption = ''
209
+ self.caption_file = ''
210
+ self.caption_file_path = ''
211
+ self.caption_file_name = ''
212
+ self.caption_file_ext = ''
213
+ self.caption_file_name_no_ext = ''
214
+ self.output_format='text'
215
+ #check if bad_files.txt exists
216
+ if os.path.exists("bad_files.txt"):
217
+ #delete it
218
+ os.remove("bad_files.txt")
219
+ self.use_blip = True
220
+ self.debug = False
221
+ self.create_widgets()
222
+ self.load_blip_model()
223
+ self.load_options()
224
+ #self.open_folder()
225
+
226
+ self.canvas.focus_force()
227
+ self.canvas.bind("<Alt-Right>", self.next_image)
228
+ self.canvas.bind("<Alt-Left>", self.prev_image)
229
+ #on close window
230
+ self.protocol("WM_DELETE_WINDOW", self.on_closing)
231
+ def on_closing(self):
232
+ #self.save_options()
233
+ if self.mainProcess is not None: self.mainProcess.deiconify()  # mainProcess is None when Caption Buddy is launched standalone
234
+ self.destroy()
235
+ def create_widgets(self):
236
+ self.output_folder = ''
237
+
238
+ # add a checkbox to toggle auto generate caption
239
+ self.auto_generate_caption = tk.BooleanVar(self.top_subframe)
240
+ self.auto_generate_caption.set(True)
241
+ self.auto_generate_caption_checkbox = ctk.CTkCheckBox(self.top_subframe, text="Auto Generate Caption", variable=self.auto_generate_caption,width=50)
242
+ self.auto_generate_caption_checkbox.pack(side="left", fill="x", expand=True, padx=10)
243
+
244
+ # add a checkbox to skip auto generating captions if they already exist
245
+ self.auto_generate_caption_text_override = tk.BooleanVar(self.top_subframe)
246
+ self.auto_generate_caption_text_override.set(False)
247
+ self.auto_generate_caption_checkbox_text_override = ctk.CTkCheckBox(self.top_subframe, text="Skip Auto Generate If Text Caption Exists", variable=self.auto_generate_caption_text_override,width=50)
248
+ self.auto_generate_caption_checkbox_text_override.pack(side="left", fill="x", expand=True, padx=10)
249
+
250
+ # add a checkbox to enable mask editing
251
+ self.enable_mask_editing = tk.BooleanVar(self.top_subframe)
252
+ self.enable_mask_editing.set(False)
253
+ self.enable_mask_editing_checkbox = ctk.CTkCheckBox(self.top_subframe, text="Enable Mask Editing", variable=self.enable_mask_editing, width=50)
254
+ self.enable_mask_editing_checkbox.pack(side="left", fill="x", expand=True, padx=10)
255
+
256
+ self.open_button = ctk.CTkButton(self.top_frame,text="Load Folder",fg_color=("gray75", "gray25"), command=self.open_folder,width=50)
257
+ #self.open_button.grid(row=0, column=1)
258
+ self.open_button.pack(side="left", fill="x",expand=True,padx=10)
259
+ #add a batch folder button
260
+ self.batch_folder_caption_button = ctk.CTkButton(self.top_frame, text="Batch Folder Caption", fg_color=("gray75", "gray25"), command=self.batch_folder_caption, width=50)
261
+ self.batch_folder_caption_button.pack(side="left", fill="x", expand=True, padx=10)
262
+ self.batch_folder_mask_button = ctk.CTkButton(self.top_frame, text="Batch Folder Mask", fg_color=("gray75", "gray25"), command=self.batch_folder_mask, width=50)
263
+ self.batch_folder_mask_button.pack(side="left", fill="x", expand=True, padx=10)
264
+
265
+ #add an options button to the same row as the open button
266
+ self.options_button = ctk.CTkButton(self.top_frame, text="Options",fg_color=("gray75", "gray25"), command=self.open_options,width=50)
267
+ self.options_button.pack(side="left", fill="x",expand=True,padx=10)
268
+ #add generate caption button
269
+ self.generate_caption_button = ctk.CTkButton(self.top_frame, text="Generate Caption",fg_color=("gray75", "gray25"), command=self.generate_caption,width=50)
270
+ self.generate_caption_button.pack(side="left", fill="x",expand=True,padx=10)
271
+
272
+ #add a label for tips under the buttons
273
+ self.tips_label = ctk.CTkLabel(self.tip_frame, text="Use Alt with left and right arrow keys to navigate images, enter to save the caption.")
274
+ self.tips_label.pack(side="top")
275
+ #add image count label
276
+ self.image_count_label = ctk.CTkLabel(self.tip_frame, text=f"Image {self.cur_img_index} of {self.image_count}")
277
+ self.image_count_label.pack(side="top")
278
+
279
+ self.image_label = ctk.CTkLabel(self.canvas,text='',width=100,height=100)
280
+ self.image_label.grid(row=0, column=0, sticky="nsew")
281
+ #self.image_label.bind("<Button-3>", self.click_canvas)
282
+ self.image_label.bind("<Motion>", self.draw_mask)
283
+ self.image_label.bind("<Button-1>", self.draw_mask)
284
+ self.image_label.bind("<Button-3>", self.draw_mask)
285
+ self.image_label.bind("<MouseWheel>", self.draw_mask_radius)
286
+ #self.image_label.pack(side="top")
287
+ #previous button
288
+ self.prev_button = ctk.CTkButton(self.frame,text="Previous", command= lambda event=None: self.prev_image(event),width=50)
289
+ #grid
290
+ self.prev_button.grid(row=1, column=0, sticky="w",padx=5,pady=10)
291
+ #self.prev_button.pack(side="left")
292
+ #self.prev_button.bind("<Left>", self.prev_image)
293
+ self.caption_entry = ctk.CTkEntry(self.frame)
294
+ #grid
295
+ self.caption_entry.grid(row=1, column=1, rowspan=3, sticky="nsew",pady=10)
296
+ #bind to enter key
297
+ self.caption_entry.bind("<Return>", self.save)
298
+ self.canvas.bind("<Return>", self.save)
299
+ self.caption_entry.bind("<Alt-Right>", self.next_image)
300
+ self.caption_entry.bind("<Alt-Left>", self.prev_image)
301
+ self.caption_entry.bind("<Control-BackSpace>", self.delete_word)
302
+ #next button
303
+
304
+ self.next_button = ctk.CTkButton(self.frame,text='Next', command= lambda event=None: self.next_image(event),width=50)
305
+ #self.next_button["text"] = "Next"
306
+ #grid
307
+ self.next_button.grid(row=1, column=2, sticky="e",padx=5,pady=10)
308
+ #add two entry boxes and labels in the style of :replace _ with _
309
+ #create replace string variable
310
+ self.replace_label = ctk.CTkLabel(self.bottom_frame, text="Replace:")
311
+ self.replace_label.grid(row=0, column=0, sticky="w",padx=5)
312
+ self.replace_entry = ctk.CTkEntry(self.bottom_frame, )
313
+ self.replace_entry.grid(row=0, column=1, sticky="nsew",padx=5)
314
+ self.replace_entry.bind("<Return>", self.save)
315
+ #self.replace_entry.bind("<Tab>", self.replace)
316
+ #with label
317
+ #create with string variable
318
+ self.with_label = ctk.CTkLabel(self.bottom_frame, text="With:")
319
+ self.with_label.grid(row=0, column=2, sticky="w",padx=5)
320
+ self.with_entry = ctk.CTkEntry(self.bottom_frame, )
321
+ self.with_entry.grid(row=0, column=3, sticky="nswe",padx=5)
322
+ self.with_entry.bind("<Return>", self.save)
323
+ #add another entry with label, add suffix
324
+
325
+ #create prefix string var
326
+ self.prefix_label = ctk.CTkLabel(self.bottom_frame, text="Add to start:")
327
+ self.prefix_label.grid(row=0, column=4, sticky="w",padx=5)
328
+ self.prefix_entry = ctk.CTkEntry(self.bottom_frame, )
329
+ self.prefix_entry.grid(row=0, column=5, sticky="nsew",padx=5)
330
+ self.prefix_entry.bind("<Return>", self.save)
331
+
332
+ #create suffix string var
333
+ self.suffix_label = ctk.CTkLabel(self.bottom_frame, text="Add to end:")
334
+ self.suffix_label.grid(row=0, column=6, sticky="w",padx=5)
335
+ self.suffix_entry = ctk.CTkEntry(self.bottom_frame, )
336
+ self.suffix_entry.grid(row=0, column=7, sticky="nsew",padx=5)
337
+ self.suffix_entry.bind("<Return>", self.save)
338
+ self.all_entries = [self.replace_entry, self.with_entry, self.suffix_entry, self.caption_entry, self.prefix_entry]
339
+ #bind right click menu to all entries
340
+ for entry in self.all_entries:
341
+ entry.bind("<Button-3>", self.create_right_click_menu)
342
+ def batch_folder_caption(self):
343
+ #show imgs in folder askdirectory
344
+ #ask user if to batch current folder or select folder
345
+ #if bad_files.txt exists, delete it
346
+ self.bad_files = []
347
+ if os.path.exists('bad_files.txt'):
348
+ os.remove('bad_files.txt')
349
+ try:
350
+ #check if self.folder is set
351
+ self.folder
352
+ except AttributeError:
353
+ self.folder = ''
354
+ if self.folder == '':
355
+ self.folder = fd.askdirectory(title="Select Folder to Batch Process", initialdir=os.getcwd())
356
+ batch_input_dir = self.folder
357
+ else:
358
+ ask = tk.messagebox.askquestion("Batch Folder", "Batch current folder?")
359
+ if ask == 'yes':
360
+ batch_input_dir = self.folder
361
+ else:
362
+ batch_input_dir = fd.askdirectory(title="Select Folder to Batch Process", initialdir=os.getcwd())
363
+ ask2 = tk.messagebox.askquestion("Batch Folder", "Save output to same directory?")
364
+ if ask2 == 'yes':
365
+ batch_output_dir = batch_input_dir
366
+ else:
367
+ batch_output_dir = fd.askdirectory(title="Select Folder to Save Batch Processed Images", initialdir=os.getcwd())
368
+ if batch_input_dir == '':
369
+ return
370
+ if batch_output_dir == '':
371
+ batch_output_dir = batch_input_dir
372
+
373
+ self.caption_file_name = os.path.basename(batch_input_dir)
374
+ self.image_list = []
375
+ for file in os.listdir(batch_input_dir):
376
+ if _check_file_type(file) and not file.endswith('-masklabel.png'):
377
+ self.image_list.append(os.path.join(batch_input_dir, file))
378
+ self.image_index = 0
379
+ #use progress bar class
380
+ #pba = tk.Tk()
381
+ #pba.title("Batch Processing")
382
+ #remove icon
383
+ #pba.wm_attributes('-toolwindow','True')
384
+ pb = ProgressbarWithCancel(max=len(self.image_list))
385
+ #pb.set_max(len(self.image_list))
386
+ pb.set_progress(0)
387
+
388
+ #if batch_output_dir doesn't exist, create it
389
+ if not os.path.exists(batch_output_dir):
390
+ os.makedirs(batch_output_dir)
391
+ for i in range(len(self.image_list)):
392
+ random_chance = random.randint(0, 25)
393
+ if random_chance == 0:
394
+ pb.set_random_label()
395
+ if pb.is_cancelled():
396
+ pb.destroy()
397
+ return
398
+ self.image_index = i
399
+ #get float value of progress between 0 and 1 according to the image index and the total number of images
400
+ progress = i / len(self.image_list)
401
+ pb.set_progress(progress)
402
+ self.update()
403
+ try:
404
+ img = Image.open(self.image_list[i]).convert("RGB")
405
+ except:
406
+ self.bad_files.append(self.image_list[i])
407
+ #skip file
408
+ continue
409
+ tensor = transforms.Compose([
410
+ transforms.Resize((self.blipSize, self.blipSize), interpolation=InterpolationMode.BICUBIC),
411
+ transforms.ToTensor(),
412
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
413
+ ])
414
+ torch_image = tensor(img).unsqueeze(0).to(torch.device("cuda"))
415
+ if self.nucleus_sampling:
416
+ captions = self.blip_decoder.generate(torch_image, sample=True, top_p=self.q_factor)
417
+ else:
418
+ captions = self.blip_decoder.generate(torch_image, sample=False, num_beams=16, min_length=self.min_length, \
419
+ max_length=48, repetition_penalty=self.q_factor)
420
+ caption = captions[0]
421
+ self.replace = self.replace_entry.get()
422
+ self.replace_with = self.with_entry.get()
423
+ self.suffix_var = self.suffix_entry.get()
424
+ self.prefix = self.prefix_entry.get()
425
+ #prepare the caption
426
+ if self.suffix_var.startswith(',') or self.suffix_var.startswith(' '):
427
+ self.suffix_var = self.suffix_var
428
+ else:
429
+ self.suffix_var = ' ' + self.suffix_var
430
+ caption = caption.replace(self.replace, self.replace_with)
431
+ if self.prefix != '':
432
+ if self.prefix.endswith(' '):
433
+ self.prefix = self.prefix[:-1]
434
+ if not self.prefix.endswith(','):
435
+ self.prefix = self.prefix+','
436
+ caption = self.prefix + ' ' + caption
437
+ if caption.endswith(',') or caption.endswith('.'):
438
+ caption = caption[:-1]
439
+ caption = caption +', ' + self.suffix_var
440
+ else:
441
+ caption = caption + self.suffix_var
442
+ #saving the captioned image
443
+ if self.output_format == 'text':
444
+ #text file with same name as image
445
+ imgName = os.path.basename(self.image_list[self.image_index])
446
+ imgName = imgName[:imgName.rfind('.')]
447
+ caption_file = os.path.join(batch_output_dir, imgName + '.txt')
448
+ with open(caption_file, 'w') as f:
449
+ f.write(caption)
450
+ elif self.output_format == 'filename':
451
+ #duplicate image with caption as file name
452
+ img.save(os.path.join(batch_output_dir, caption+'.png'))
453
+ progress = (i + 1) / len(self.image_list)
454
+ pb.set_progress(progress)
455
+ #show message box when done
456
+ pb.destroy()
457
+ donemsg = tk.messagebox.showinfo("Batch Folder", "Batching complete!",parent=self.master)
458
+ if len(self.bad_files) > 0:
459
+ bad_files_msg = tk.messagebox.showinfo("Bad Files", "Couldn't process " + str(len(self.bad_files)) + " files.\nFor a list of problematic files, see bad_files.txt", parent=self.master)
460
+ with open('bad_files.txt', 'w') as f:
461
+ for item in self.bad_files:
462
+ f.write(item + '\n')
463
+
464
+ #ask user if we should load the batch output folder
465
+ ask3 = tk.messagebox.askquestion("Batch Folder", "Load batch output folder?")
466
+ if ask3 == 'yes':
467
+ self.image_index = 0
468
+ self.open_folder(folder=batch_output_dir)
469
+ #focus on donemsg
470
+ #donemsg.focus_force()
471
+ def generate_caption(self):
472
+ #get the image
473
+ tensor = transforms.Compose([
474
+ #transforms.CenterCrop(SIZE),
475
+ transforms.Resize((self.blipSize, self.blipSize), interpolation=InterpolationMode.BICUBIC),
476
+ transforms.ToTensor(),
477
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
478
+ ])
479
+ torch_image = tensor(self.PILimage).unsqueeze(0).to(torch.device("cuda"))
480
+ if self.nucleus_sampling:
481
+ captions = self.blip_decoder.generate(torch_image, sample=True, top_p=self.q_factor)
482
+ else:
483
+ captions = self.blip_decoder.generate(torch_image, sample=False, num_beams=16, min_length=self.min_length, \
484
+ max_length=48, repetition_penalty=self.q_factor)
485
+ self.caption = captions[0]
486
+ self.caption_entry.delete(0, tk.END)
487
+ self.caption_entry.insert(0, self.caption)
488
+ #change the caption entry color to red
489
+ self.caption_entry.configure(fg_color='red')
490
+ def load_blip_model(self):
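+ # Downloads the BLIP captioning checkpoint (model_base_caption_capfilt_large.pth) on first run,
+ # then loads the decoder onto CUDA; sampling settings fall back to defaults if options.json is missing.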
491
+ self.blipSize = 384
492
+ blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
493
+ #check if options file exists
494
+ if os.path.exists(os.path.join(self.captioner_folder, 'options.json')):
495
+ with open(os.path.join(self.captioner_folder, 'options.json'), 'r') as f:
496
+ opts = json.load(f)  # load the JSON once; calling json.load repeatedly on the same file object fails
+ self.nucleus_sampling = opts['nucleus_sampling']
497
+ self.q_factor = opts['q_factor']
498
+ self.min_length = opts['min_length']
499
+ else:
500
+ self.nucleus_sampling = False
501
+ self.q_factor = 1.0
502
+ self.min_length = 22
503
+ config_path = os.path.join(self.captioner_folder, "BLIP/configs/med_config.json")
504
+ cache_folder = os.path.join(self.captioner_folder, "BLIP/cache")
505
+ model_path = os.path.join(self.captioner_folder, "BLIP/models/model_base_caption_capfilt_large.pth")
506
+ if not os.path.exists(cache_folder):
507
+ os.makedirs(cache_folder)
508
+
509
+ if not os.path.exists(model_path):
510
+ print(f"Downloading BLIP to {model_path}")
511
+ with requests.get(blip_model_url, stream=True) as session:
512
+ session.raise_for_status()
513
+ with open(model_path, 'wb') as f:
514
+ for chunk in session.iter_content(chunk_size=1024):
515
+ f.write(chunk)
516
+ print('Download complete')
517
+ else:
518
+ print("Found BLIP model")
519
+ import models.blip
520
+ blip_decoder = models.blip.blip_decoder(pretrained=model_path, image_size=self.blipSize, vit='base', med_config=config_path)
521
+ blip_decoder.eval()
522
+ self.blip_decoder = blip_decoder.to(torch.device("cuda"))
523
+
524
+ def batch_folder_mask(self):
525
+ folder = ''
526
+ try:
527
+ # check if self.folder is set
528
+ folder = self.folder
529
+ except:
530
+ pass
531
+
532
+ dialog = BatchMaskWindow(self, folder)
533
+ dialog.mainloop()
534
+
535
+ def load_clip_seg_model(self):
536
+ if self.clip_seg is None:
537
+ self.clip_seg = ClipSeg()
538
+
539
+ def open_folder(self,folder=None):
540
+ if folder is None:
541
+ self.folder = fd.askdirectory()
542
+ else:
543
+ self.folder = folder
544
+ if self.folder == '':
545
+ return
546
+ self.output_folder = self.folder
547
+ self.image_list = [os.path.join(self.folder, f) for f in os.listdir(self.folder) if _check_file_type(f) and not f.endswith('-masklabel.png') and not f.endswith('-depth.png')]
548
+ #self.image_list.sort()
549
+ #sort the image list alphabetically so that the images are in the same order every time
550
+ self.image_list.sort(key=lambda x: x.lower())
551
+
552
+ self.image_count = len(self.image_list)
553
+ if self.image_count == 0:
554
+ tk.messagebox.showinfo("No Images", "No images found in the selected folder")
555
+ return
556
+ #update the image count label
557
+
558
+ self.image_index = 0
559
+ self.image_count_label.configure(text=f'Image {self.image_index+1} of {self.image_count}')
560
+ self.output_folder = self.folder
561
+ self.load_image()
562
+ self.caption_entry.focus_set()
563
+
564
+ def draw_mask(self, event):
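+ # Mouse-driven mask painting: left button draws in white, right button erases (black);
+ # the mouse-wheel handler below adjusts the brush radius, and event coordinates are rescaled
+ # from the resized preview back to the original image resolution.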
565
+ if not self.enable_mask_editing.get():
566
+ return
567
+
568
+ if event.widget != self.image_label.children["!label"]:
569
+ return
570
+
571
+ start_x = int(event.x / self.image_size[0] * self.PILimage.width)
572
+ start_y = int(event.y / self.image_size[1] * self.PILimage.height)
573
+ end_x = int(self.mask_draw_x / self.image_size[0] * self.PILimage.width)
574
+ end_y = int(self.mask_draw_y / self.image_size[1] * self.PILimage.height)
575
+
576
+ self.mask_draw_x = event.x
577
+ self.mask_draw_y = event.y
578
+
579
+ color = None
580
+
581
+ if event.state & 0x0100 or event.num == 1: # left mouse button
582
+ color = (255, 255, 255)
583
+ elif event.state & 0x0400 or event.num == 3: # right mouse button
584
+ color = (0, 0, 0)
585
+
586
+ if color is not None:
587
+ if self.PILmask is None:
588
+ self.PILmask = Image.new('RGB', size=self.PILimage.size, color=(0, 0, 0))
589
+
590
+ draw = ImageDraw.Draw(self.PILmask)
591
+ draw.line((start_x, start_y, end_x, end_y), fill=color, width=self.mask_draw_radius + self.mask_draw_radius + 1)
592
+ draw.ellipse((start_x - self.mask_draw_radius, start_y - self.mask_draw_radius, start_x + self.mask_draw_radius, start_y + self.mask_draw_radius), fill=color, outline=None)
593
+ draw.ellipse((end_x - self.mask_draw_radius, end_y - self.mask_draw_radius, end_x + self.mask_draw_radius, end_y + self.mask_draw_radius), fill=color, outline=None)
594
+
595
+ self.compose_masked_image()
596
+ self.display_image()
597
+
598
+ def draw_mask_radius(self, event):
599
+ if event.widget != self.image_label.children["!label"]:
600
+ return
601
+
602
+ delta = -np.sign(event.delta) * 5
603
+ self.mask_draw_radius += delta
604
+
605
+ def compose_masked_image(self):
606
+ np_image = np.array(self.PILimage).astype(np.float32) / 255.0
607
+ np_mask = np.array(self.PILmask).astype(np.float32) / 255.0
608
+ np_mask = np.clip(np_mask, 0.4, 1.0)
609
+ np_masked_image = (np_image * np_mask * 255.0).astype(np.uint8)
610
+ self.image = Image.fromarray(np_masked_image, mode='RGB')
611
+
612
+ def display_image(self):
613
+ #resize to fit 600x600 while maintaining aspect ratio
614
+ width, height = self.image.size
615
+ if width > height:
616
+ new_width = 600
617
+ new_height = int(600 * height / width)
618
+ else:
619
+ new_height = 600
620
+ new_width = int(600 * width / height)
621
+ self.image_size = (new_width, new_height)
622
+ self.image = self.image.resize(self.image_size, Image.Resampling.LANCZOS)
623
+ self.image = ctk.CTkImage(self.image, size=self.image_size)
624
+ self.image_label.configure(image=self.image)
625
+
626
+ def load_image(self):
627
+ try:
628
+ self.PILimage = Image.open(self.image_list[self.image_index]).convert('RGB')
629
+ except:
630
+ print(f'Error opening image {self.image_list[self.image_index]}')
631
+ print('Logged path to bad_files.txt')
632
+ #if bad_files.txt doesn't exist, create it
633
+ if not os.path.exists('bad_files.txt'):
634
+ with open('bad_files.txt', 'w') as f:
635
+ f.write(self.image_list[self.image_index]+'\n')
636
+ else:
637
+ with open('bad_files.txt', 'a') as f:
638
+ f.write(self.image_list[self.image_index]+'\n')
639
+ return
640
+
641
+ self.image = self.PILimage.copy()
642
+
643
+ try:
644
+ self.PILmask = None
645
+ mask_filename = os.path.splitext(self.image_list[self.image_index])[0] + '-masklabel.png'
646
+ if os.path.exists(mask_filename):
647
+ self.PILmask = Image.open(mask_filename).convert('RGB')
648
+ self.compose_masked_image()
649
+ except Exception as e:
650
+ print(f'Error opening mask for {self.image_list[self.image_index]}')
651
+ print('Logged path to bad_files.txt')
652
+ #if bad_files.txt doesn't exist, create it
653
+ if not os.path.exists('bad_files.txt'):
654
+ with open('bad_files.txt', 'w') as f:
655
+ f.write(self.image_list[self.image_index]+'\n')
656
+ else:
657
+ with open('bad_files.txt', 'a') as f:
658
+ f.write(self.image_list[self.image_index]+'\n')
659
+ return
660
+
661
+ self.display_image()
662
+
663
+ self.caption_file_path = self.image_list[self.image_index]
664
+ self.caption_file_name = os.path.basename(self.caption_file_path)
665
+ self.caption_file_ext = os.path.splitext(self.caption_file_name)[1]
666
+ self.caption_file_name_no_ext = os.path.splitext(self.caption_file_name)[0]
667
+ self.caption_file = os.path.join(self.folder, self.caption_file_name_no_ext + '.txt')
668
+ if os.path.isfile(self.caption_file) and self.auto_generate_caption.get() == False or os.path.isfile(self.caption_file) and self.auto_generate_caption.get() == True and self.auto_generate_caption_text_override.get() == True:
669
+ with open(self.caption_file, 'r') as f:
670
+ self.caption = f.read()
671
+ self.caption_entry.delete(0, tk.END)
672
+ self.caption_entry.insert(0, self.caption)
673
+ self.caption_entry.configure(fg_color=ThemeManager.theme["CTkEntry"]["fg_color"])
674
+ self.use_blip = False
675
+ elif os.path.isfile(self.caption_file) and self.auto_generate_caption.get() == True and self.auto_generate_caption_text_override.get() == False or os.path.isfile(self.caption_file)==False and self.auto_generate_caption.get() == True and self.auto_generate_caption_text_override.get() == True:
676
+ self.use_blip = True
677
+ self.caption_entry.delete(0, tk.END)
678
+ elif os.path.isfile(self.caption_file) == False and self.auto_generate_caption.get() == False:
679
+ self.caption_entry.delete(0, tk.END)
680
+ return
681
+ if self.use_blip and self.debug==False:
682
+ tensor = transforms.Compose([
683
+ transforms.Resize((self.blipSize, self.blipSize), interpolation=InterpolationMode.BICUBIC),
684
+ transforms.ToTensor(),
685
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
686
+ ])
687
+ torch_image = tensor(self.PILimage).unsqueeze(0).to(torch.device("cuda"))
688
+ if self.nucleus_sampling:
689
+ captions = self.blip_decoder.generate(torch_image, sample=True, top_p=self.q_factor)
690
+ else:
691
+ captions = self.blip_decoder.generate(torch_image, sample=False, num_beams=16, min_length=self.min_length, \
692
+ max_length=48, repetition_penalty=self.q_factor)
693
+ self.caption = captions[0]
694
+ self.caption_entry.delete(0, tk.END)
695
+ self.caption_entry.insert(0, self.caption)
696
+ #change the caption entry color to red
697
+ self.caption_entry.configure(fg_color='red')
698
+
699
+ def save(self, event):
700
+ self.save_caption()
701
+
702
+ if self.enable_mask_editing.get():
703
+ self.save_mask()
704
+
705
+ def save_mask(self):
706
+ mask_filename = os.path.splitext(self.image_list[self.image_index])[0] + '-masklabel.png'
707
+ if self.PILmask is not None:
708
+ self.PILmask.save(mask_filename)
709
+
710
+ def save_caption(self):
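+ # Applies the Replace/With substitution, then the "Add to start" prefix (a comma is appended to it)
+ # and the "Add to end" suffix, then writes the result either to a sidecar .txt file or as the
+ # filename of a duplicated image, depending on the selected output format.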
711
+ self.caption = self.caption_entry.get()
712
+ self.replace = self.replace_entry.get()
713
+ self.replace_with = self.with_entry.get()
714
+ self.suffix_var = self.suffix_entry.get()
715
+ self.prefix = self.prefix_entry.get()
716
+ #prepare the caption
717
+ self.caption = self.caption.replace(self.replace, self.replace_with)
718
+ if self.suffix_var.startswith(',') or self.suffix_var.startswith(' '):
719
+ self.suffix_var = self.suffix_var
720
+ else:
721
+ self.suffix_var = ' ' + self.suffix_var
722
+ if self.prefix != '':
723
+ if self.prefix.endswith(' '):
724
+ self.prefix = self.prefix[:-1]
725
+ if not self.prefix.endswith(','):
726
+ self.prefix = self.prefix+','
727
+ self.caption = self.prefix + ' ' + self.caption
728
+ if self.caption.endswith(',') or self.caption.endswith('.'):
729
+ self.caption = self.caption[:-1]
730
+ self.caption = self.caption +', ' + self.suffix_var
731
+ else:
732
+ self.caption = self.caption + self.suffix_var
733
+ self.caption = self.caption.strip()
734
+ if self.output_folder != self.folder:
735
+ outputFolder = self.output_folder
736
+ else:
737
+ outputFolder = self.folder
738
+ if self.output_format == 'text':
739
+ #text file with same name as image
740
+ #image name
741
+ #print('test')
742
+ imgName = os.path.basename(self.image_list[self.image_index])
743
+ imgName = imgName[:imgName.rfind('.')]
744
+ self.caption_file = os.path.join(outputFolder, imgName + '.txt')
745
+ with open(self.caption_file, 'w') as f:
746
+ f.write(self.caption)
747
+ elif self.output_format == 'filename':
748
+ #duplicate image with caption as file name
749
+ #make sure self.caption doesn't contain any illegal characters
750
+ illegal_chars = ['/', '\\', ':', '*', '?', '"', "'",'<', '>', '|', '.']
751
+ for char in illegal_chars:
752
+ self.caption = self.caption.replace(char, '')
753
+ self.PILimage.save(os.path.join(outputFolder, self.caption+'.png'))
754
+ self.caption_entry.delete(0, tk.END)
755
+ self.caption_entry.insert(0, self.caption)
756
+ self.caption_entry.configure(fg_color='green')
757
+
758
+ self.caption_entry.focus_force()
759
+ def delete_word(self,event):
760
+ ent = event.widget
761
+ end_idx = ent.index(tk.INSERT)
762
+ start_idx = ent.get().rfind(" ", None, end_idx)
763
+ ent.selection_range(start_idx, end_idx)
764
+ def prev_image(self, event):
765
+ if self.image_index > 0:
766
+ self.image_index -= 1
767
+ self.image_count_label.configure(text=f'Image {self.image_index+1} of {self.image_count}')
768
+ self.load_image()
769
+ self.caption_entry.focus_set()
770
+ self.caption_entry.focus_force()
771
+ def next_image(self, event):
772
+ if self.image_index < len(self.image_list) - 1:
773
+ self.image_index += 1
774
+ self.image_count_label.configure(text=f'Image {self.image_index+1} of {self.image_count}')
775
+ self.load_image()
776
+ self.caption_entry.focus_set()
777
+ self.caption_entry.focus_force()
778
+ def open_options(self):
779
+ self.options_window = ctk.CTkToplevel(self)
780
+ self.options_window.title("Options")
781
+ self.options_window.geometry("320x550")
782
+ #disable reszie
783
+ self.options_window.resizable(False, False)
784
+ self.options_window.focus_force()
785
+ self.options_window.grab_set()
786
+ self.options_window.transient(self)
787
+ self.options_window.protocol("WM_DELETE_WINDOW", self.close_options)
788
+ #add title label
789
+ self.options_title_label = ctk.CTkLabel(self.options_window, text="Options",font=ctk.CTkFont(size=20, weight="bold"))
790
+ self.options_title_label.pack(side="top", pady=5)
791
+ #add an entry with a button to select a folder as output folder
792
+ self.output_folder_label = ctk.CTkLabel(self.options_window, text="Output Folder")
793
+ self.output_folder_label.pack(side="top", pady=5)
794
+ self.output_folder_entry = ctk.CTkEntry(self.options_window)
795
+ self.output_folder_entry.pack(side="top", fill="x", expand=False,padx=15, pady=5)
796
+ self.output_folder_entry.insert(0, self.output_folder)
797
+ self.output_folder_button = ctk.CTkButton(self.options_window, text="Select Folder", command=self.select_output_folder,fg_color=("gray75", "gray25"))
798
+ self.output_folder_button.pack(side="top", pady=5)
799
+ #add radio buttons to select the output format between text and filename
800
+ self.output_format_label = ctk.CTkLabel(self.options_window, text="Output Format")
801
+ self.output_format_label.pack(side="top", pady=5)
802
+ self.output_format_var = tk.StringVar(self.options_window)
803
+ self.output_format_var.set(self.output_format)
804
+ self.output_format_text = ctk.CTkRadioButton(self.options_window, text="Text File", variable=self.output_format_var, value="text")
805
+ self.output_format_text.pack(side="top", pady=5)
806
+ self.output_format_filename = ctk.CTkRadioButton(self.options_window, text="File name", variable=self.output_format_var, value="filename")
807
+ self.output_format_filename.pack(side="top", pady=5)
808
+ #add BLIP settings section
809
+ self.blip_settings_label = ctk.CTkLabel(self.options_window, text="BLIP Settings",font=ctk.CTkFont(size=20, weight="bold"))
810
+ self.blip_settings_label.pack(side="top", pady=10)
811
+ #add a checkbox to use nucleas sampling or not
812
+ self.nucleus_sampling_var = tk.IntVar(self.options_window)
813
+ self.nucleus_sampling_checkbox = ctk.CTkCheckBox(self.options_window, text="Use nucleus sampling", variable=self.nucleus_sampling_var)
814
+ self.nucleus_sampling_checkbox.pack(side="top", pady=5)
815
+ if self.debug:
816
+ self.nucleus_sampling = 0
817
+ self.q_factor = 0.5
818
+ self.min_length = 10
819
+ self.nucleus_sampling_var.set(self.nucleus_sampling)
820
+ #add a float entry to set the q factor
821
+ self.q_factor_label = ctk.CTkLabel(self.options_window, text="Q Factor")
822
+ self.q_factor_label.pack(side="top", pady=5)
823
+ self.q_factor_entry = ctk.CTkEntry(self.options_window)
824
+ self.q_factor_entry.insert(0, self.q_factor)
825
+ self.q_factor_entry.pack(side="top", pady=5)
826
+ #add a int entry to set the number minimum length
827
+ self.min_length_label = ctk.CTkLabel(self.options_window, text="Minimum Length")
828
+ self.min_length_label.pack(side="top", pady=5)
829
+ self.min_length_entry = ctk.CTkEntry(self.options_window)
830
+ self.min_length_entry.insert(0, self.min_length)
831
+ self.min_length_entry.pack(side="top", pady=5)
832
+ #add a horozontal radio button to select between None, ViT-L-14/openai, ViT-H-14/laion2b_s32b_b79k
833
+ #self.model_label = ctk.CTkLabel(self.options_window, text="CLIP Interrogation")
834
+ #self.model_label.pack(side="top")
835
+ #self.model_var = tk.StringVar(self.options_window)
836
+ #self.model_var.set(self.model)
837
+ #self.model_none = tk.Radiobutton(self.options_window, text="None", variable=self.model_var, value="None")
838
+ #self.model_none.pack(side="top")
839
+ #self.model_vit_l_14 = tk.Radiobutton(self.options_window, text="ViT-L-14/openai", variable=self.model_var, value="ViT-L-14/openai")
840
+ #self.model_vit_l_14.pack(side="top")
841
+ #self.model_vit_h_14 = tk.Radiobutton(self.options_window, text="ViT-H-14/laion2b_s32b_b79k", variable=self.model_var, value="ViT-H-14/laion2b_s32b_b79k")
842
+ #self.model_vit_h_14.pack(side="top")
843
+
844
+ #add a save button
845
+ self.save_button = ctk.CTkButton(self.options_window, text="Save", command=self.save_options, fg_color=("gray75", "gray25"))
846
+ self.save_button.pack(side="top",fill='x',pady=10,padx=10)
847
+ #all entries list
848
+ entries = [self.output_folder_entry, self.q_factor_entry, self.min_length_entry]
849
+ #bind the right click to all entries
850
+ for entry in entries:
851
+ entry.bind("<Button-3>", self.create_right_click_menu)
852
+ self.options_file = os.path.join(self.captioner_folder, 'captioner_options.json')
853
+ if os.path.isfile(self.options_file):
854
+ with open(self.options_file, 'r') as f:
855
+ self.options = json.load(f)
856
+ self.output_folder_entry.delete(0, tk.END)
857
+ self.output_folder_entry.insert(0, self.output_folder)
858
+ self.output_format_var.set(self.options['output_format'])
859
+ self.nucleus_sampling_var.set(self.options['nucleus_sampling'])
860
+ self.q_factor_entry.delete(0, tk.END)
861
+ self.q_factor_entry.insert(0, self.options['q_factor'])
862
+ self.min_length_entry.delete(0, tk.END)
863
+ self.min_length_entry.insert(0, self.options['min_length'])
864
+ def load_options(self):
865
+ self.options_file = os.path.join(self.captioner_folder, 'captioner_options.json')
866
+ if os.path.isfile(self.options_file):
867
+ with open(self.options_file, 'r') as f:
868
+ self.options = json.load(f)
869
+ #self.output_folder = self.folder
870
+ #self.output_folder = self.options['output_folder']
871
+ if 'folder' in self.__dict__:
872
+ self.output_folder = self.folder
873
+ else:
874
+ self.output_folder = ''
875
+ self.output_format = self.options['output_format']
876
+ self.nucleus_sampling = self.options['nucleus_sampling']
877
+ self.q_factor = self.options['q_factor']
878
+ self.min_length = self.options['min_length']
879
+ else:
880
+ #if self has folder, use it, otherwise use the current folder
881
+ if 'folder' in self.__dict__ :
882
+ self.output_folder = self.folder
883
+ else:
884
+ self.output_folder = ''
885
+ self.output_format = "text"
886
+ self.nucleus_sampling = False
887
+ self.q_factor = 0.9
888
+ self.min_length =22
889
+ def save_options(self):
890
+ self.output_folder = self.output_folder_entry.get()
891
+ self.output_format = self.output_format_var.get()
892
+ self.nucleus_sampling = self.nucleus_sampling_var.get()
893
+ self.q_factor = float(self.q_factor_entry.get())
894
+ self.min_length = int(self.min_length_entry.get())
895
+ #save options to a file
896
+ self.options_file = os.path.join(self.captioner_folder, 'captioner_options.json')
897
+ with open(self.options_file, 'w') as f:
898
+ json.dump({'output_folder': self.output_folder, 'output_format': self.output_format, 'nucleus_sampling': self.nucleus_sampling, 'q_factor': self.q_factor, 'min_length': self.min_length}, f)
899
+ self.close_options()
900
+
901
+ def select_output_folder(self):
902
+ self.output_folder = fd.askdirectory()
903
+ self.output_folder_entry.delete(0, tk.END)
904
+ self.output_folder_entry.insert(0, self.output_folder)
905
+ def close_options(self):
906
+ self.options_window.destroy()
907
+ self.caption_entry.focus_force()
908
+ def create_right_click_menu(self, event):
909
+ #create a menu
910
+ self.menu = Menu(self, tearoff=0)
911
+ #add commands to the menu
912
+ self.menu.add_command(label="Cut", command=lambda: self.focus_get().event_generate("<<Cut>>"))
913
+ self.menu.add_command(label="Copy", command=lambda: self.focus_get().event_generate("<<Copy>>"))
914
+ self.menu.add_command(label="Paste", command=lambda: self.focus_get().event_generate("<<Paste>>"))
915
+ self.menu.add_command(label="Select All", command=lambda: self.focus_get().event_generate("<<SelectAll>>"))
916
+ #display the menu
917
+ try:
918
+ self.menu.tk_popup(event.x_root, event.y_root)
919
+ finally:
920
+ #make sure to release the grab (Tk 8.0a1 only)
921
+ self.menu.grab_release()
922
+
923
+
924
+ #progress bar class with cancel button
925
+ class ProgressbarWithCancel(ctk.CTkToplevel):
926
+ def __init__(self,max=None, **kw):
927
+ super().__init__(**kw)
928
+ self.title("Batching...")
929
+ self.max = max
930
+ self.possibleLabels = ['Searching for answers...',"I'm working, I promise.",'ARE THOSE TENTACLES?!','Weird data man...','Another one bites the dust' ,"I think it's a cat?" ,'Looking for the meaning of life', 'Dreaming of captions']
931
+
932
+ self.label = ctk.CTkLabel(self, text="Searching for answers...")
933
+ self.label.pack(side="top", fill="x", expand=True,padx=10,pady=10)
934
+ self.progress = ctk.CTkProgressBar(self, orientation="horizontal", mode="determinate")
935
+ self.progress.pack(side="left", fill="x", expand=True,padx=10,pady=10)
936
+ self.cancel_button = ctk.CTkButton(self, text="Cancel", command=self.cancel)
937
+ self.cancel_button.pack(side="right",padx=10,pady=10)
938
+ self.cancelled = False
939
+ self.count_label = ctk.CTkLabel(self, text="0/{0}".format(self.max))
940
+ self.count_label.pack(side="right",padx=10,pady=10)
941
+ def set_random_label(self):
942
+ import random
943
+ self.label.configure(text=random.choice(self.possibleLabels))  # use configure(); tk-style item assignment is unreliable on CTk widgets
944
+ #pop from list
945
+ #self.possibleLabels.remove(self.label["text"])
946
+ def cancel(self):
947
+ self.cancelled = True
948
+ def set_progress(self, value):
949
+ self.progress.set(value)
950
+ self.count_label.configure(text="{0}/{1}".format(int(value * self.max), self.max))
951
+ def get_progress(self):
952
+ return self.progress.get()  # call the getter rather than returning the bound method
953
+ def set_max(self, value):
954
+ self.max = value  # store the new maximum so set_progress and get_max use it
955
+ def get_max(self):
956
+ return self.max  # CTkProgressBar has no tk-style "maximum" option; use the stored value
957
+ def is_cancelled(self):
958
+ return self.cancelled
959
+ #quit the progress bar window
960
+
961
+
962
+ #run when executed directly as a script
963
+ if __name__ == "__main__":
964
+
965
+ #root = tk.Tk()
966
+ app = ImageBrowser()
967
+ app.mainloop()
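+
+ # Standalone launch (illustrative; assumes customtkinter, torch, the cloned BLIP repo and a CUDA
+ # device are available in the environment):
+ # python captionBuddy.py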
StableTuner_RunPod_Fix/clip_segmentation.py ADDED
@@ -0,0 +1,325 @@
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Callable
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from torch import Tensor, nn
8
+ from torchvision.transforms import transforms, functional
9
+ from tqdm.auto import tqdm
10
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
11
+
12
+ DEVICE = "cuda"
13
+
14
+
15
+ def parse_args():
16
+ parser = argparse.ArgumentParser(description="ClipSeg script.")
17
+ parser.add_argument(
18
+ "--sample_dir",
19
+ type=str,
20
+ required=True,
21
+ help="directory where samples are located",
22
+ )
23
+ parser.add_argument(
24
+ "--add_prompt",
25
+ type=str,
26
+ required=True,
27
+ action="append",
28
+ help="a prompt used to create a mask",
29
+ dest="prompts",
30
+ )
31
+ parser.add_argument(
32
+ "--mode",
33
+ type=str,
34
+ default='fill',
35
+ required=False,
36
+ help="Either replace, fill, add or subtract",
37
+ )
38
+ parser.add_argument(
39
+ "--threshold",
40
+ type=float,
41
+ default=0.3,
42
+ required=False,
43
+ help="threshold for including pixels in the mask",
44
+ )
45
+ parser.add_argument(
46
+ "--smooth_pixels",
47
+ type=int,
48
+ default=5,
49
+ required=False,
50
+ help="radius of a smoothing operation applied to the generated mask",
51
+ )
52
+ parser.add_argument(
53
+ "--expand_pixels",
54
+ type=int,
55
+ default=10,
56
+ required=False,
57
+ help="amount of expansion of the generated mask in all directions",
58
+ )
59
+
60
+ args = parser.parse_args()
61
+ return args
62
+
63
+
64
+ class MaskSample:
65
+ def __init__(self, filename: str):
66
+ self.image_filename = filename
67
+ self.mask_filename = os.path.splitext(filename)[0] + "-masklabel.png"
68
+
69
+ self.image = None
70
+ self.mask_tensor = None
71
+
72
+ self.height = 0
73
+ self.width = 0
74
+
75
+ self.image2Tensor = transforms.Compose([
76
+ transforms.ToTensor(),
77
+ ])
78
+
79
+ self.tensor2Image = transforms.Compose([
80
+ transforms.ToPILImage(),
81
+ ])
82
+
83
+ def get_image(self) -> Image:
84
+ if self.image is None:
85
+ self.image = Image.open(self.image_filename).convert('RGB')
86
+ self.height = self.image.height
87
+ self.width = self.image.width
88
+
89
+ return self.image
90
+
91
+ def get_mask_tensor(self) -> Tensor:
92
+ if self.mask_tensor is None and os.path.exists(self.mask_filename):
93
+ mask = Image.open(self.mask_filename).convert('L')
94
+ mask = self.image2Tensor(mask)
95
+ mask = mask.to(DEVICE)
96
+ self.mask_tensor = mask.unsqueeze(0)
97
+
98
+ return self.mask_tensor
99
+
100
+ def set_mask_tensor(self, mask_tensor: Tensor):
101
+ self.mask_tensor = mask_tensor
102
+
103
+ def add_mask_tensor(self, mask_tensor: Tensor):
104
+ mask = self.get_mask_tensor()
105
+ if mask is None:
106
+ mask = mask_tensor
107
+ else:
108
+ mask += mask_tensor
109
+ mask = torch.clamp(mask, 0, 1)
110
+
111
+ self.mask_tensor = mask
112
+
113
+ def subtract_mask_tensor(self, mask_tensor: Tensor):
114
+ mask = self.get_mask_tensor()
115
+ if mask is None:
116
+ mask = mask_tensor
117
+ else:
118
+ mask -= mask_tensor
119
+ mask = torch.clamp(mask, 0, 1)
120
+
121
+ self.mask_tensor = mask
122
+
123
+ def save_mask(self):
124
+ if self.mask_tensor is not None:
125
+ mask = self.mask_tensor.cpu().squeeze()
126
+ mask = self.tensor2Image(mask).convert('RGB')
127
+ mask.save(self.mask_filename)
128
+
129
+
130
+ class ClipSeg:
131
+ def __init__(self):
132
+ self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
133
+
134
+ self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
135
+ self.model.eval()
136
+ self.model.to(DEVICE)
137
+
138
+ self.smoothing_kernel_radius = None
139
+ self.smoothing_kernel = self.__create_average_kernel(self.smoothing_kernel_radius)
140
+
141
+ self.expand_kernel_radius = None
142
+ self.expand_kernel = self.__create_average_kernel(self.expand_kernel_radius)
143
+
144
+ @staticmethod
145
+ def __create_average_kernel(kernel_radius: Optional[int]):
146
+ if kernel_radius is None:
147
+ return None
148
+
149
+ kernel_size = kernel_radius * 2 + 1
150
+ kernel_weights = torch.ones(1, 1, kernel_size, kernel_size) / (kernel_size * kernel_size)
151
+ kernel = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel_size, bias=False, padding_mode='replicate', padding=kernel_radius)
152
+ kernel.weight.data = kernel_weights
153
+ kernel.requires_grad_(False)
154
+ kernel.to(DEVICE)
155
+ return kernel
156
+
157
+ @staticmethod
158
+ def __get_sample_filenames(sample_dir: str) -> [str]:
159
+ filenames = []
160
+ for filename in os.listdir(sample_dir):
161
+ ext = os.path.splitext(filename)[1].lower()
162
+ if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp'] and '-masklabel.png' not in filename:
163
+ filenames.append(os.path.join(sample_dir, filename))
164
+
165
+ return filenames
166
+
167
+ def __process_mask(self, mask: Tensor, target_height: int, target_width: int, threshold: float) -> Tensor:
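+ # Post-processing of the raw CLIPSeg logits: sigmoid -> sum over the prompt dimension ->
+ # optional box-filter smoothing -> resize to the source resolution -> binarize at `threshold` ->
+ # optional expansion via a second box filter, re-binarized at > 0.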
168
+ while len(mask.shape) < 4:
169
+ mask = mask.unsqueeze(0)
170
+
171
+ mask = torch.sigmoid(mask)
172
+ mask = mask.sum(1).unsqueeze(1)
173
+ if self.smoothing_kernel is not None:
174
+ mask = self.smoothing_kernel(mask)
175
+ mask = functional.resize(mask, [target_height, target_width])
176
+ mask = (mask > threshold).float()
177
+ if self.expand_kernel is not None:
178
+ mask = self.expand_kernel(mask)
179
+ mask = (mask > 0).float()
180
+
181
+ return mask
182
+
183
+ def mask_image(self, filename: str, prompts: [str], mode: str = 'fill', threshold: float = 0.3, smooth_pixels: int = 5, expand_pixels: int = 10):
184
+ """
185
+ Masks a sample
186
+
187
+ Parameters:
188
+ filename (`str`): a sample filename
189
+ prompts (`[str]`): a list of prompts used to create a mask
190
+ mode (`str`): can be one of
191
+ - replace: creates new masks for all samples, even if a mask already exists
192
+ - fill: creates new masks for all samples without a mask
193
+ - add: adds the new region to existing masks
194
+ - subtract: subtracts the new region from existing masks
195
+ threshold (`float`): threshold for including pixels in the mask
196
+ smooth_pixels (`int`): radius of a smoothing operation applied to the generated mask
197
+ expand_pixels (`int`): amount of expansion of the generated mask in all directions
198
+ """
199
+
200
+ mask_sample = MaskSample(filename)
201
+
202
+ if mode == 'fill' and mask_sample.get_mask_tensor() is not None:
203
+ return
204
+
205
+ if self.smoothing_kernel_radius != smooth_pixels:
206
+ self.smoothing_kernel = self.__create_average_kernel(smooth_pixels)
207
+ self.smoothing_kernel_radius = smooth_pixels
208
+
209
+ if self.expand_kernel_radius != expand_pixels:
210
+ self.expand_kernel = self.__create_average_kernel(expand_pixels)
211
+ self.expand_kernel_radius = expand_pixels
212
+
213
+ inputs = self.processor(text=prompts, images=[mask_sample.get_image()] * len(prompts), padding="max_length", return_tensors="pt")
214
+ inputs.to(DEVICE)
215
+ with torch.no_grad():
216
+ outputs = self.model(**inputs)
217
+ predicted_mask = self.__process_mask(outputs.logits, mask_sample.height, mask_sample.width, threshold)
218
+
219
+ if mode == 'replace' or mode == 'fill':
220
+ mask_sample.set_mask_tensor(predicted_mask)
221
+ elif mode == 'add':
222
+ mask_sample.add_mask_tensor(predicted_mask)
223
+ elif mode == 'subtract':
224
+ mask_sample.subtract_mask_tensor(predicted_mask)
225
+
226
+ mask_sample.save_mask()
227
+
228
+ def mask_folder(
229
+ self,
230
+ sample_dir: str,
231
+ prompts: [str],
232
+ mode: str = 'fill',
233
+ threshold: float = 0.3,
234
+ smooth_pixels: int = 5,
235
+ expand_pixels: int = 10,
236
+ progress_callback: Callable[[int, int], None] = None,
237
+ error_callback: Callable[[str], None] = None,
238
+ ):
239
+ """
240
+ Masks all samples in a folder
241
+
242
+ Parameters:
243
+ sample_dir (`str`): directory where samples are located
244
+ prompts (`[str]`): a list of prompts used to create a mask
245
+ mode (`str`): can be one of
246
+ - replace: creates new masks for all samples, even if a mask already exists
247
+ - fill: creates new masks for all samples without a mask
248
+ - add: adds the new region to existing masks
249
+ - subtract: subtracts the new region from existing masks
250
+ threshold (`float`): threshold for including pixels in the mask
251
+ smooth_pixels (`int`): radius of a smoothing operation applied to the generated mask
252
+ expand_pixels (`int`): amount of expansion of the generated mask in all directions
253
+ progress_callback (`Callable[[int, int], None]`): called after every processed image
254
+ error_callback (`Callable[[str], None]`): called for every exception
255
+ """
256
+
257
+ filenames = self.__get_sample_filenames(sample_dir)
258
+ self.mask_images(
259
+ filenames=filenames,
260
+ prompts=prompts,
261
+ mode=mode,
262
+ threshold=threshold,
263
+ smooth_pixels=smooth_pixels,
264
+ expand_pixels=expand_pixels,
265
+ progress_callback=progress_callback,
266
+ error_callback=error_callback,
267
+ )
268
+
269
+ def mask_images(
270
+ self,
271
+ filenames: [str],
272
+ prompts: [str],
273
+ mode: str = 'fill',
274
+ threshold: float = 0.3,
275
+ smooth_pixels: int = 5,
276
+ expand_pixels: int = 10,
277
+ progress_callback: Callable[[int, int], None] = None,
278
+ error_callback: Callable[[str], None] = None,
279
+ ):
280
+ """
281
+ Masks all samples in a list
282
+
283
+ Parameters:
284
+ filenames (`[str]`): a list of sample filenames
285
+ prompts (`[str]`): a list of prompts used to create a mask
286
+ mode (`str`): can be one of
287
+ - replace: creates new masks for all samples, even if a mask already exists
288
+ - fill: creates new masks for all samples without a mask
289
+ - add: adds the new region to existing masks
290
+ - subtract: subtracts the new region from existing masks
291
+ threshold (`float`): threshold for including pixels in the mask
292
+ smooth_pixels (`int`): radius of a smoothing operation applied to the generated mask
293
+ expand_pixels (`int`): amount of expansion of the generated mask in all directions
294
+ progress_callback (`Callable[[int, int], None]`): called after every processed image
295
+ error_callback (`Callable[[str], None]`): called for every exception
296
+ """
297
+
298
+ if progress_callback is not None:
299
+ progress_callback(0, len(filenames))
300
+ for i, filename in enumerate(tqdm(filenames)):
301
+ try:
302
+ self.mask_image(filename, prompts, mode, threshold, smooth_pixels, expand_pixels)
303
+ except Exception as e:
304
+ if error_callback is not None:
305
+ error_callback(filename)
306
+ if progress_callback is not None:
307
+ progress_callback(i + 1, len(filenames))
308
+
309
+
310
+ def main():
311
+ args = parse_args()
312
+ clip_seg = ClipSeg()
313
+ clip_seg.mask_folder(
314
+ sample_dir=args.sample_dir,
315
+ prompts=args.prompts,
316
+ mode=args.mode,
317
+ threshold=args.threshold,
318
+ smooth_pixels=args.smooth_pixels,
319
+ expand_pixels=args.expand_pixels,
320
+ error_callback=lambda filename: print("Error while processing image " + filename)
321
+ )
322
+
323
+
324
+ if __name__ == "__main__":
325
+ main()
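For reference, a minimal usage sketch of the ClipSeg masking helper defined above when called from other code instead of the CLI entry point; the folder path and prompts are placeholder values, and the keyword arguments mirror the mask_folder signature in this file:

    from clip_segmentation import ClipSeg

    clip_seg = ClipSeg()
    clip_seg.mask_folder(
        sample_dir="./datasets/my_concept",   # placeholder dataset folder
        prompts=["a person"],                 # placeholder mask prompts
        mode="fill",                          # only create masks for images without one
        threshold=0.3,
        smooth_pixels=5,
        expand_pixels=10,
        error_callback=lambda filename: print("Error while processing image " + filename),
    )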
StableTuner_RunPod_Fix/configuration_gui.py ADDED
The diff for this file is too large to render. See raw diff
 
StableTuner_RunPod_Fix/convert_diffusers_to_sd_cli.py ADDED
@@ -0,0 +1,22 @@
1
+ import sys
2
+ import os
3
+ try:
4
+ import converters
5
+ except ImportError:
6
+
7
+ #if there's a scripts folder where the script is, add it to the path
8
+ if 'scripts' in os.listdir(os.path.dirname(os.path.abspath(__file__))):
9
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'scripts'))
10
+ else:
11
+ print('Could not find scripts folder. Please add it to the path manually or place this file in it.')
12
+ import converters
13
+
14
+
15
+ if __name__ == '__main__':
16
+ args = sys.argv[1:]
17
+ if len(args) != 2:
18
+ print('Usage: python3 convert_diffusers_to_sd_cli.py <model_path> <output_path>')
19
+ sys.exit(1)
20
+ model_path = args[0]
21
+ output_path = args[1]
22
+ converters.Convert_Diffusers_to_SD(model_path, output_path)
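As a quick sanity check, the wrapper above is meant to be run exactly as its usage string describes; the two paths below are placeholders for a diffusers model folder and the desired checkpoint output:

    python3 convert_diffusers_to_sd_cli.py ./diffusers_model_dir ./converted_model.ckpt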
StableTuner_RunPod_Fix/converters.py ADDED
@@ -0,0 +1,120 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import requests
16
+ import os
17
+ import os.path as osp
18
+ import torch
19
+ try:
20
+ from omegaconf import OmegaConf
21
+ except ImportError:
22
+ raise ImportError(
23
+ "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
24
+ )
25
+
26
+ from diffusers import (
27
+ AutoencoderKL,
28
+ DDIMScheduler,
29
+ DPMSolverMultistepScheduler,
30
+ EulerAncestralDiscreteScheduler,
31
+ EulerDiscreteScheduler,
32
+ HeunDiscreteScheduler,
33
+ LDMTextToImagePipeline,
34
+ LMSDiscreteScheduler,
35
+ PNDMScheduler,
36
+ StableDiffusionPipeline,
37
+ UNet2DConditionModel,
38
+ DiffusionPipeline
39
+ )
40
+ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
41
+ #from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
42
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
43
+ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig, CLIPTextConfig
44
+ import model_util
45
+
46
+ class Convert_SD_to_Diffusers():
47
+
48
+ def __init__(self, checkpoint_path, output_path, prediction_type=None, img_size=None, original_config_file=None, extract_ema=False, num_in_channels=None,pipeline_type=None,scheduler_type=None,sd_version=None,half=None,version=None):
49
+ self.checkpoint_path = checkpoint_path
50
+ self.output_path = output_path
51
+ self.prediction_type = prediction_type
52
+ self.img_size = img_size
53
+ self.original_config_file = original_config_file
54
+ self.extract_ema = extract_ema
55
+ self.num_in_channels = num_in_channels
56
+ self.pipeline_type = pipeline_type
57
+ self.scheduler_type = scheduler_type
58
+ self.sd_version = sd_version
59
+ self.half = half
60
+ self.version = version
61
+ self.main()
62
+
63
+
64
+ def main(self):
65
+ image_size = self.img_size
66
+ prediction_type = self.prediction_type
67
+ original_config_file = self.original_config_file
68
+ num_in_channels = self.num_in_channels
69
+ scheduler_type = self.scheduler_type
70
+ pipeline_type = self.pipeline_type
71
+ extract_ema = self.extract_ema
72
+ reference_diffusers_model = None
73
+ if self.version == 'v1':
74
+ is_v1 = True
75
+ is_v2 = False
76
+ if self.version == 'v2':
77
+ is_v1 = False
78
+ is_v2 = True
79
+ if is_v2 == True and prediction_type == 'vprediction':
80
+ reference_diffusers_model = 'stabilityai/stable-diffusion-2'
81
+ if is_v2 == True and prediction_type == 'epsilon':
82
+ reference_diffusers_model = 'stabilityai/stable-diffusion-2-base'
83
+ if is_v1 == True and prediction_type == 'epsilon':
84
+ reference_diffusers_model = 'runwayml/stable-diffusion-v1-5'
85
+ dtype = 'fp16' if self.half else None
86
+ v2_model = True if is_v2 else False
87
+ print(f"loading model from: {self.checkpoint_path}")
88
+ #print(v2_model)
89
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, self.checkpoint_path)
90
+ print(f"copy scheduler/tokenizer config from: {reference_diffusers_model}")
91
+ model_util.save_diffusers_checkpoint(v2_model, self.output_path, text_encoder, unet, reference_diffusers_model, vae)
92
+ print(f"Diffusers model saved.")
93
+
94
+
95
+
96
+ class Convert_Diffusers_to_SD():
97
+ def __init__(self,model_path=None, output_path=None):
98
+ pass
99
+ def main(model_path:str, output_path:str):
100
+ #print(model_path)
101
+ #print(output_path)
102
+ global_step = None
103
+ epoch = None
104
+ dtype = torch.float32
105
+ pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype, tokenizer=None, safety_checker=None)
106
+ text_encoder = pipe.text_encoder
107
+ vae = pipe.vae
108
+ if os.path.exists(os.path.join(model_path, "ema_unet")):
109
+ pipe.unet = UNet2DConditionModel.from_pretrained(
110
+ model_path,
111
+ subfolder="ema_unet",
112
+ torch_dtype=dtype
113
+ )
114
+ unet = pipe.unet
115
+ v2_model = unet.config.cross_attention_dim == 1024
116
+ original_model = None
117
+ key_count = model_util.save_stable_diffusion_checkpoint(v2_model, output_path, text_encoder, unet,
118
+ original_model, epoch, global_step, dtype, vae)
119
+ print(f"Saved model")
120
+ return main(model_path, output_path)
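A minimal sketch of driving the two converter classes above from Python; all paths are placeholders, and the keyword values follow the branches in Convert_SD_to_Diffusers.main() (version 'v1'/'v2', prediction_type 'epsilon'/'vprediction'):

    import converters

    # .ckpt -> diffusers folder (v1 with the epsilon objective picks runwayml/stable-diffusion-v1-5 as reference)
    converters.Convert_SD_to_Diffusers(
        checkpoint_path="./model.ckpt",        # placeholder checkpoint to convert
        output_path="./diffusers_model_dir",   # placeholder output folder
        prediction_type="epsilon",
        version="v1",
        half=False,
    )

    # diffusers folder -> .ckpt, same call the CLI wrapper makes
    converters.Convert_Diffusers_to_SD("./diffusers_model_dir", "./converted_model.ckpt")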
StableTuner_RunPod_Fix/dataloaders_util.py ADDED
@@ -0,0 +1,1331 @@
1
+ import random
2
+ import math
3
+ import os
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from torch.utils.data import Dataset
7
+ from torchvision import transforms
8
+ from tqdm.auto import tqdm
9
+ import numpy as np
10
+ from PIL import Image
11
+ from trainer_util import *
12
+ from clip_segmentation import ClipSeg
13
+
14
+ class bcolors:
15
+ HEADER = '\033[95m'
16
+ OKBLUE = '\033[94m'
17
+ OKCYAN = '\033[96m'
18
+ OKGREEN = '\033[92m'
19
+ WARNING = '\033[93m'
20
+ FAIL = '\033[91m'
21
+ ENDC = '\033[0m'
22
+ BOLD = '\033[1m'
23
+ UNDERLINE = '\033[4m'
24
+ ASPECT_2048 = [[2048, 2048],
25
+ [2112, 1984],[1984, 2112],
26
+ [2176, 1920],[1920, 2176],
27
+ [2240, 1856],[1856, 2240],
28
+ [2304, 1792],[1792, 2304],
29
+ [2368, 1728],[1728, 2368],
30
+ [2432, 1664],[1664, 2432],
31
+ [2496, 1600],[1600, 2496],
32
+ [2560, 1536],[1536, 2560],
33
+ [2624, 1472],[1472, 2624]]
34
+ ASPECT_1984 = [[1984, 1984],
35
+ [2048, 1920],[1920, 2048],
36
+ [2112, 1856],[1856, 2112],
37
+ [2176, 1792],[1792, 2176],
38
+ [2240, 1728],[1728, 2240],
39
+ [2304, 1664],[1664, 2304],
40
+ [2368, 1600],[1600, 2368],
41
+ [2432, 1536],[1536, 2432],
42
+ [2496, 1472],[1472, 2496],
43
+ [2560, 1408],[1408, 2560]]
44
+ ASPECT_1920 = [[1920, 1920],
45
+ [1984, 1856],[1856, 1984],
46
+ [2048, 1792],[1792, 2048],
47
+ [2112, 1728],[1728, 2112],
48
+ [2176, 1664],[1664, 2176],
49
+ [2240, 1600],[1600, 2240],
50
+ [2304, 1536],[1536, 2304],
51
+ [2368, 1472],[1472, 2368],
52
+ [2432, 1408],[1408, 2432],
53
+ [2496, 1344],[1344, 2496]]
54
+ ASPECT_1856 = [[1856, 1856],
55
+ [1920, 1792],[1792, 1920],
56
+ [1984, 1728],[1728, 1984],
57
+ [2048, 1664],[1664, 2048],
58
+ [2112, 1600],[1600, 2112],
59
+ [2176, 1536],[1536, 2176],
60
+ [2240, 1472],[1472, 2240],
61
+ [2304, 1408],[1408, 2304],
62
+ [2368, 1344],[1344, 2368],
63
+ [2432, 1280],[1280, 2432]]
64
+ ASPECT_1792 = [[1792, 1792],
65
+ [1856, 1728],[1728, 1856],
66
+ [1920, 1664],[1664, 1920],
67
+ [1984, 1600],[1600, 1984],
68
+ [2048, 1536],[1536, 2048],
69
+ [2112, 1472],[1472, 2112],
70
+ [2176, 1408],[1408, 2176],
71
+ [2240, 1344],[1344, 2240],
72
+ [2304, 1280],[1280, 2304],
73
+ [2368, 1216],[1216, 2368]]
74
+ ASPECT_1728 = [[1728, 1728],
75
+ [1792, 1664],[1664, 1792],
76
+ [1856, 1600],[1600, 1856],
77
+ [1920, 1536],[1536, 1920],
78
+ [1984, 1472],[1472, 1984],
79
+ [2048, 1408],[1408, 2048],
80
+ [2112, 1344],[1344, 2112],
81
+ [2176, 1280],[1280, 2176],
82
+ [2240, 1216],[1216, 2240],
83
+ [2304, 1152],[1152, 2304]]
84
+ ASPECT_1664 = [[1664, 1664],
85
+ [1728, 1600],[1600, 1728],
86
+ [1792, 1536],[1536, 1792],
87
+ [1856, 1472],[1472, 1856],
88
+ [1920, 1408],[1408, 1920],
89
+ [1984, 1344],[1344, 1984],
90
+ [2048, 1280],[1280, 2048],
91
+ [2112, 1216],[1216, 2112],
92
+ [2176, 1152],[1152, 2176],
93
+ [2240, 1088],[1088, 2240]]
94
+ ASPECT_1600 = [[1600, 1600],
95
+ [1664, 1536],[1536, 1664],
96
+ [1728, 1472],[1472, 1728],
97
+ [1792, 1408],[1408, 1792],
98
+ [1856, 1344],[1344, 1856],
99
+ [1920, 1280],[1280, 1920],
100
+ [1984, 1216],[1216, 1984],
101
+ [2048, 1152],[1152, 2048],
102
+ [2112, 1088],[1088, 2112],
103
+ [2176, 1024],[1024, 2176]]
104
+ ASPECT_1536 = [[1536, 1536],
105
+ [1600, 1472],[1472, 1600],
106
+ [1664, 1408],[1408, 1664],
107
+ [1728, 1344],[1344, 1728],
108
+ [1792, 1280],[1280, 1792],
109
+ [1856, 1216],[1216, 1856],
110
+ [1920, 1152],[1152, 1920],
111
+ [1984, 1088],[1088, 1984],
112
+ [2048, 1024],[1024, 2048],
113
+ [2112, 960],[960, 2112]]
114
+ ASPECT_1472 = [[1472, 1472],
115
+ [1536, 1408],[1408, 1536],
116
+ [1600, 1344],[1344, 1600],
117
+ [1664, 1280],[1280, 1664],
118
+ [1728, 1216],[1216, 1728],
119
+ [1792, 1152],[1152, 1792],
120
+ [1856, 1088],[1088, 1856],
121
+ [1920, 1024],[1024, 1920],
122
+ [1984, 960],[960, 1984],
123
+ [2048, 896],[896, 2048]]
124
+ ASPECT_1408 = [[1408, 1408],
125
+ [1472, 1344],[1344, 1472],
126
+ [1536, 1280],[1280, 1536],
127
+ [1600, 1216],[1216, 1600],
128
+ [1664, 1152],[1152, 1664],
129
+ [1728, 1088],[1088, 1728],
130
+ [1792, 1024],[1024, 1792],
131
+ [1856, 960],[960, 1856],
132
+ [1920, 896],[896, 1920],
133
+ [1984, 832],[832, 1984]]
134
+ ASPECT_1344 = [[1344, 1344],
135
+ [1408, 1280],[1280, 1408],
136
+ [1472, 1216],[1216, 1472],
137
+ [1536, 1152],[1152, 1536],
138
+ [1600, 1088],[1088, 1600],
139
+ [1664, 1024],[1024, 1664],
140
+ [1728, 960],[960, 1728],
141
+ [1792, 896],[896, 1792],
142
+ [1856, 832],[832, 1856],
143
+ [1920, 768],[768, 1920]]
144
+ ASPECT_1280 = [[1280, 1280],
145
+ [1344, 1216],[1216, 1344],
146
+ [1408, 1152],[1152, 1408],
147
+ [1472, 1088],[1088, 1472],
148
+ [1536, 1024],[1024, 1536],
149
+ [1600, 960],[960, 1600],
150
+ [1664, 896],[896, 1664],
151
+ [1728, 832],[832, 1728],
152
+ [1792, 768],[768, 1792],
153
+ [1856, 704],[704, 1856]]
154
+ ASPECT_1216 = [[1216, 1216],
155
+ [1280, 1152],[1152, 1280],
156
+ [1344, 1088],[1088, 1344],
157
+ [1408, 1024],[1024, 1408],
158
+ [1472, 960],[960, 1472],
159
+ [1536, 896],[896, 1536],
160
+ [1600, 832],[832, 1600],
161
+ [1664, 768],[768, 1664],
162
+ [1728, 704],[704, 1728],
163
+ [1792, 640],[640, 1792]]
164
+ ASPECT_1152 = [[1152, 1152],
165
+ [1216, 1088],[1088, 1216],
166
+ [1280, 1024],[1024, 1280],
167
+ [1344, 960],[960, 1344],
168
+ [1408, 896],[896, 1408],
169
+ [1472, 832],[832, 1472],
170
+ [1536, 768],[768, 1536],
171
+ [1600, 704],[704, 1600],
172
+ [1664, 640],[640, 1664],
173
+ [1728, 576],[576, 1728]]
174
+ ASPECT_1088 = [[1088, 1088],
175
+ [1152, 1024],[1024, 1152],
176
+ [1216, 960],[960, 1216],
177
+ [1280, 896],[896, 1280],
178
+ [1344, 832],[832, 1344],
179
+ [1408, 768],[768, 1408],
180
+ [1472, 704],[704, 1472],
181
+ [1536, 640],[640, 1536],
182
+ [1600, 576],[576, 1600],
183
+ [1664, 512],[512, 1664]]
184
+ ASPECT_832 = [[832, 832],
185
+ [896, 768], [768, 896],
186
+ [960, 704], [704, 960],
187
+ [1024, 640], [640, 1024],
188
+ [1152, 576], [576, 1152],
189
+ [1280, 512], [512, 1280],
190
+ [1344, 512], [512, 1344],
191
+ [1408, 448], [448, 1408],
192
+ [1472, 448], [448, 1472],
193
+ [1536, 384], [384, 1536],
194
+ [1600, 384], [384, 1600]]
195
+
196
+ ASPECT_896 = [[896, 896],
197
+ [960, 832], [832, 960],
198
+ [1024, 768], [768, 1024],
199
+ [1088, 704], [704, 1088],
200
+ [1152, 704], [704, 1152],
201
+ [1216, 640], [640, 1216],
202
+ [1280, 640], [640, 1280],
203
+ [1344, 576], [576, 1344],
204
+ [1408, 576], [576, 1408],
205
+ [1472, 512], [512, 1472],
206
+ [1536, 512], [512, 1536],
207
+ [1600, 448], [448, 1600],
208
+ [1664, 448], [448, 1664]]
209
+ ASPECT_960 = [[960, 960],
210
+ [1024, 896],[896, 1024],
211
+ [1088, 832],[832, 1088],
212
+ [1152, 768],[768, 1152],
213
+ [1216, 704],[704, 1216],
214
+ [1280, 640],[640, 1280],
215
+ [1344, 576],[576, 1344],
216
+ [1408, 512],[512, 1408],
217
+ [1472, 448],[448, 1472],
218
+ [1536, 384],[384, 1536]]
219
+ ASPECT_1024 = [[1024, 1024],
220
+ [1088, 960], [960, 1088],
221
+ [1152, 896], [896, 1152],
222
+ [1216, 832], [832, 1216],
223
+ [1344, 768], [768, 1344],
224
+ [1472, 704], [704, 1472],
225
+ [1600, 640], [640, 1600],
226
+ [1728, 576], [576, 1728],
227
+ [1792, 576], [576, 1792]]
228
+ ASPECT_768 = [[768,768], # 589824 1:1
229
+ [896,640],[640,896], # 573440 1.4:1
230
+ [832,704],[704,832], # 585728 1.181:1
231
+ [960,576],[576,960], # 552960 1.6:1
232
+ [1024,576],[576,1024], # 524288 1.778:1
233
+ [1088,512],[512,1088], # 497664 2.125:1
234
+ [1152,512],[512,1152], # 589824 2.25:1
235
+ [1216,448],[448,1216], # 552960 2.714:1
236
+ [1280,448],[448,1280], # 573440 2.857:1
237
+ [1344,384],[384,1344], # 518400 3.5:1
238
+ [1408,384],[384,1408], # 540672 3.667:1
239
+ [1472,320],[320,1472], # 470400 4.6:1
240
+ [1536,320],[320,1536], # 491520 4.8:1
241
+ ]
242
+
243
+ ASPECT_704 = [[704,704], # 501,376 1:1
244
+ [768,640],[640,768], # 491,520 1.2:1
245
+ [832,576],[576,832], # 458,752 1.444:1
246
+ [896,512],[512,896], # 458,752 1.75:1
247
+ [960,512],[512,960], # 491,520 1.875:1
248
+ [1024,448],[448,1024], # 458,752 2.286:1
249
+ [1088,448],[448,1088], # 487,424 2.429:1
250
+ [1152,384],[384,1152], # 442,368 3:1
251
+ [1216,384],[384,1216], # 466,944 3.125:1
252
+ [1280,384],[384,1280], # 491,520 3.333:1
253
+ [1280,320],[320,1280], # 409,600 4:1
254
+ [1408,320],[320,1408], # 450,560 4.4:1
255
+ [1536,320],[320,1536], # 491,520 4.8:1
256
+ ]
257
+
258
+ ASPECT_640 = [[640,640], # 409600 1:1
259
+ [704,576],[576,704], # 405504 1.25:1
260
+ [768,512],[512,768], # 393216 1.5:1
261
+ [896,448],[448,896], # 401408 2:1
262
+ [1024,384],[384,1024], # 393216 2.667:1
263
+ [1280,320],[320,1280], # 409600 4:1
264
+ [1408,256],[256,1408], # 360448 5.5:1
265
+ [1472,256],[256,1472], # 376832 5.75:1
266
+ [1536,256],[256,1536], # 393216 6:1
267
+ [1600,256],[256,1600], # 409600 6.25:1
268
+ ]
269
+
270
+ ASPECT_576 = [[576,576], # 331776 1:1
271
+ [640,512],[512,640], # 327680 1.25:1
272
+ [640,448],[448,640], # 286720 1.4286:1
273
+ [704,448],[448,704], # 314928 1.5625:1
274
+ [832,384],[384,832], # 317440 2.1667:1
275
+ [1024,320],[320,1024], # 327680 3.2:1
276
+ [1280,256],[256,1280], # 327680 5:1
277
+ ]
278
+
279
+ ASPECT_512 = [[512,512], # 262144 1:1
280
+ [576,448],[448,576], # 258048 1.29:1
281
+ [640,384],[384,640], # 245760 1.667:1
282
+ [768,320],[320,768], # 245760 2.4:1
283
+ [832,256],[256,832], # 212992 3.25:1
284
+ [896,256],[256,896], # 229376 3.5:1
285
+ [960,256],[256,960], # 245760 3.75:1
286
+ [1024,256],[256,1024], # 245760 4:1
287
+ ]
288
+
289
+ ASPECT_448 = [[448,448], # 200704 1:1
290
+ [512,384],[384,512], # 196608 1.33:1
291
+ [576,320],[320,576], # 184320 1.8:1
292
+ [768,256],[256,768], # 196608 3:1
293
+ ]
294
+
295
+ ASPECT_384 = [[384,384], # 147456 1:1
296
+ [448,320],[320,448], # 143360 1.4:1
297
+ [576,256],[256,576], # 147456 2.25:1
298
+ [768,192],[192,768], # 147456 4:1
299
+ ]
300
+
301
+ ASPECT_320 = [[320,320], # 102400 1:1
302
+ [384,256],[256,384], # 98304 1.5:1
303
+ [512,192],[192,512], # 98304 2.67:1
304
+ ]
305
+
306
+ ASPECT_256 = [[256,256], # 65536 1:1
307
+ [320,192],[192,320], # 61440 1.67:1
308
+ [512,128],[128,512], # 65536 4:1
309
+ ]
310
+
311
+ #failsafe aspects
312
+ ASPECTS = ASPECT_512
313
+ def get_aspect_buckets(resolution,mode=''):
314
+ if resolution < 256:
315
+ raise ValueError("Resolution must be at least 256")
316
+ try:
317
+ rounded_resolution = int(resolution / 64) * 64
318
+ print(f" {bcolors.WARNING} Rounded resolution to: {rounded_resolution}{bcolors.ENDC}")
319
+ all_image_sizes = __get_all_aspects()
320
+ if mode == 'MJ':
321
+ #truncate each aspect list to its first 3 size pairs
322
+ all_image_sizes = [x[0:3] for x in all_image_sizes]
323
+ aspects = next(filter(lambda sizes: sizes[0][0]==rounded_resolution, all_image_sizes), None)
324
+ ASPECTS = aspects
325
+ #print(aspects)
326
+ return aspects
327
+ except Exception as e:
328
+ print(f" {bcolors.FAIL} *** Could not find selected resolution: {rounded_resolution}{bcolors.ENDC}")
329
+
330
+ raise e
331
+
332
+ def __get_all_aspects():
333
+ return [ASPECT_256, ASPECT_320, ASPECT_384, ASPECT_448, ASPECT_512, ASPECT_576, ASPECT_640, ASPECT_704, ASPECT_768,ASPECT_832,ASPECT_896,ASPECT_960,ASPECT_1024,ASPECT_1088,ASPECT_1152,ASPECT_1216,ASPECT_1280,ASPECT_1344,ASPECT_1408,ASPECT_1472,ASPECT_1536,ASPECT_1600,ASPECT_1664,ASPECT_1728,ASPECT_1792,ASPECT_1856,ASPECT_1920,ASPECT_1984,ASPECT_2048]
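+ # Illustrative note (not part of the original file): get_aspect_buckets(768) rounds the resolution
+ # down to a multiple of 64 (still 768) and returns ASPECT_768, whose first entry is [768, 768];
+ # an input of 700 rounds to 640 and selects ASPECT_640. With mode='MJ' each list is truncated to
+ # its first three size pairs.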
334
+ class AutoBucketing(Dataset):
335
+ def __init__(self,
336
+ concepts_list,
337
+ tokenizer=None,
338
+ flip_p=0.0,
339
+ repeats=1,
340
+ debug_level=0,
341
+ batch_size=1,
342
+ set='val',
343
+ resolution=512,
344
+ center_crop=False,
345
+ use_image_names_as_captions=True,
346
+ shuffle_captions=False,
347
+ add_class_images_to_dataset=None,
348
+ balance_datasets=False,
349
+ crop_jitter=20,
350
+ with_prior_loss=False,
351
+ use_text_files_as_captions=False,
352
+ aspect_mode='dynamic',
353
+ action_preference='dynamic',
354
+ seed=555,
355
+ model_variant='base',
356
+ extra_module=None,
357
+ mask_prompts=None,
358
+ load_mask=False,
359
+ ):
360
+
361
+ self.debug_level = debug_level
362
+ self.resolution = resolution
363
+ self.center_crop = center_crop
364
+ self.tokenizer = tokenizer
365
+ self.batch_size = batch_size
366
+ self.concepts_list = concepts_list
367
+ self.use_image_names_as_captions = use_image_names_as_captions
368
+ self.shuffle_captions = shuffle_captions
369
+ self.num_train_images = 0
370
+ self.num_reg_images = 0
371
+ self.image_train_items = []
372
+ self.image_reg_items = []
373
+ self.add_class_images_to_dataset = add_class_images_to_dataset
374
+ self.balance_datasets = balance_datasets
375
+ self.crop_jitter = crop_jitter
376
+ self.with_prior_loss = with_prior_loss
377
+ self.use_text_files_as_captions = use_text_files_as_captions
378
+ self.aspect_mode = aspect_mode
379
+ self.action_preference = action_preference
380
+ self.model_variant = model_variant
381
+ self.extra_module = extra_module
382
+ self.image_transforms = transforms.Compose(
383
+ [
384
+ transforms.ToTensor(),
385
+ transforms.Normalize([0.5], [0.5]),
386
+ ]
387
+ )
388
+ self.mask_transforms = transforms.Compose(
389
+ [
390
+ transforms.ToTensor(),
391
+ ]
392
+ )
393
+ self.depth_image_transforms = transforms.Compose(
394
+ [
395
+ transforms.ToTensor(),
396
+ ]
397
+ )
398
+ self.seed = seed
399
+ #shared_dataloader = None
400
+ print(f" {bcolors.WARNING}Creating Auto Bucketing Dataloader{bcolors.ENDC}")
401
+
402
+ shared_dataloader = DataLoaderMultiAspect(concepts_list,
403
+ debug_level=debug_level,
404
+ resolution=self.resolution,
405
+ seed=self.seed,
406
+ batch_size=self.batch_size,
407
+ flip_p=flip_p,
408
+ use_image_names_as_captions=self.use_image_names_as_captions,
409
+ add_class_images_to_dataset=self.add_class_images_to_dataset,
410
+ balance_datasets=self.balance_datasets,
411
+ with_prior_loss=self.with_prior_loss,
412
+ use_text_files_as_captions=self.use_text_files_as_captions,
413
+ aspect_mode=self.aspect_mode,
414
+ action_preference=self.action_preference,
415
+ model_variant=self.model_variant,
416
+ extra_module=self.extra_module,
417
+ mask_prompts=mask_prompts,
418
+ load_mask=load_mask,
419
+ )
420
+
421
+ #print(self.image_train_items)
422
+ if self.with_prior_loss and self.add_class_images_to_dataset == False:
423
+ self.image_train_items, self.class_train_items = shared_dataloader.get_all_images()
424
+ self.num_train_images = self.num_train_images + len(self.image_train_items)
425
+ self.num_reg_images = self.num_reg_images + len(self.class_train_items)
426
+ self._length = max(max(math.trunc(self.num_train_images * repeats), batch_size),math.trunc(self.num_reg_images * repeats), batch_size) - self.num_train_images % self.batch_size
427
+ self.num_train_images = self.num_train_images + self.num_reg_images
428
+
429
+ else:
430
+ self.image_train_items = shared_dataloader.get_all_images()
431
+ self.num_train_images = self.num_train_images + len(self.image_train_items)
432
+ self._length = max(math.trunc(self.num_train_images * repeats), batch_size) - self.num_train_images % self.batch_size
433
+
434
+ print()
435
+ print(f" {bcolors.WARNING} ** Validation Set: {set}, steps: {self._length / batch_size:.0f}, repeats: {repeats} {bcolors.ENDC}")
436
+ print()
437
+
438
+
439
+ def __len__(self):
440
+ return self._length
441
+
442
+ def __getitem__(self, i):
443
+ idx = i % self.num_train_images
444
+ #print(idx)
445
+ image_train_item = self.image_train_items[idx]
446
+
447
+ example = self.__get_image_for_trainer(image_train_item,debug_level=self.debug_level)
448
+ if self.with_prior_loss and self.add_class_images_to_dataset == False:
449
+ idx = i % self.num_reg_images
450
+ class_train_item = self.class_train_items[idx]
451
+ example_class = self.__get_image_for_trainer(class_train_item,debug_level=self.debug_level,class_img=True)
452
+ example= {**example, **example_class}
453
+
454
+ #print the tensor shape
455
+ #print(example['instance_images'].shape)
456
+ #print(example.keys())
457
+ return example
458
+ def normalize8(self,I):
459
+ mn = I.min()
460
+ mx = I.max()
461
+
462
+ mx -= mn
463
+
464
+ I = ((I - mn)/mx) * 255
465
+ return I.astype(np.uint8)
466
+ def __get_image_for_trainer(self,image_train_item,debug_level=0,class_img=False):
467
+ example = {}
468
+ save = debug_level > 2
469
+
470
+ if class_img==False:
471
+ image_train_tmp = image_train_item.hydrate(crop=False, save=0, crop_jitter=self.crop_jitter)
472
+ image_train_tmp_image = Image.fromarray(self.normalize8(image_train_tmp.image)).convert("RGB")
473
+
474
+ instance_prompt = image_train_tmp.caption
475
+ if self.shuffle_captions:
476
+ caption_parts = instance_prompt.split(",")
477
+ random.shuffle(caption_parts)
478
+ instance_prompt = ",".join(caption_parts)
479
+
480
+ example["instance_images"] = self.image_transforms(image_train_tmp_image)
481
+ if image_train_tmp.mask is not None:
482
+ image_train_tmp_mask = Image.fromarray(self.normalize8(image_train_tmp.mask)).convert("L")
483
+ example["mask"] = self.mask_transforms(image_train_tmp_mask)
484
+ if self.model_variant == 'depth2img':
485
+ image_train_tmp_depth = Image.fromarray(self.normalize8(image_train_tmp.extra)).convert("L")
486
+ example["instance_depth_images"] = self.depth_image_transforms(image_train_tmp_depth)
487
+ #print(instance_prompt)
488
+ example["instance_prompt_ids"] = self.tokenizer(
489
+ instance_prompt,
490
+ padding="do_not_pad",
491
+ truncation=True,
492
+ max_length=self.tokenizer.model_max_length,
493
+ ).input_ids
494
+ image_train_item.self_destruct()
495
+ return example
496
+
497
+ if class_img==True:
498
+ image_train_tmp = image_train_item.hydrate(crop=False, save=4, crop_jitter=self.crop_jitter)
499
+ image_train_tmp_image = Image.fromarray(self.normalize8(image_train_tmp.image)).convert("RGB")
500
+ if self.model_variant == 'depth2img':
501
+ image_train_tmp_depth = Image.fromarray(self.normalize8(image_train_tmp.extra)).convert("L")
502
+ example["class_depth_images"] = self.depth_image_transforms(image_train_tmp_depth)
503
+ example["class_images"] = self.image_transforms(image_train_tmp_image)
504
+ example["class_prompt_ids"] = self.tokenizer(
505
+ image_train_tmp.caption,
506
+ padding="do_not_pad",
507
+ truncation=True,
508
+ max_length=self.tokenizer.model_max_length,
509
+ ).input_ids
510
+ image_train_item.self_destruct()
511
+ return example
512
+
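+ # Illustrative note (not part of the original file): each example dict returned by __getitem__
+ # carries "instance_images" and "instance_prompt_ids", plus "mask" when a mask was loaded or
+ # generated, "instance_depth_images" for the depth2img variant, and the corresponding class_*
+ # keys when prior-preservation images are mixed in.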
513
+ _RANDOM_TRIM = 0.04
514
+ class ImageTrainItem():
515
+ """
516
+ image: Image
517
+ mask: Image
518
+ extra: Image
519
+ identifier: caption,
520
+ target_aspect: (width, height),
521
+ pathname: path to image file
522
+ flip_p: probability of flipping image (0.0 to 1.0)
523
+ """
524
+ def __init__(self, image: Image, mask: Image, extra: Image, caption: str, target_wh: list, pathname: str, flip_p=0.0, model_variant='base', load_mask=False):
525
+ self.caption = caption
526
+ self.target_wh = target_wh
527
+ self.pathname = pathname
528
+ self.mask_pathname = os.path.splitext(pathname)[0] + "-masklabel.png"
529
+ self.depth_pathname = os.path.splitext(pathname)[0] + "-depth.png"
530
+ self.flip_p = flip_p
531
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
532
+ self.cropped_img = None
533
+ self.model_variant = model_variant
534
+ self.load_mask=load_mask
535
+ self.is_dupe = []
536
+ self.variant_warning = False
537
+
538
+ self.image = image
539
+ self.mask = mask
540
+ self.extra = extra
541
+
542
+ def self_destruct(self):
543
+ self.image = None
544
+ self.mask = None
545
+ self.extra = None
546
+ self.cropped_img = None
547
+ self.is_dupe.append(1)
548
+
549
+ def load_image(self, pathname, crop, jitter_amount, flip):
550
+ if len(self.is_dupe) > 0:
551
+ self.flip = transforms.RandomHorizontalFlip(p=1.0 if flip else 0.0)
552
+ image = Image.open(pathname).convert('RGB')
553
+
554
+ width, height = image.size
555
+ if crop:
556
+ cropped_img = self.__autocrop(image)
557
+ image = cropped_img.resize((512, 512), resample=Image.Resampling.LANCZOS)
558
+ else:
559
+ width, height = image.size
560
+
561
+ if self.target_wh[0] == self.target_wh[1]:
562
+ if width > height:
563
+ left = random.randint(0, width - height)
564
+ image = image.crop((left, 0, height + left, height))
565
+ width = height
566
+ elif height > width:
567
+ top = random.randint(0, height - width)
568
+ image = image.crop((0, top, width, width + top))
569
+ height = width
570
+ elif width > self.target_wh[0]:
571
+ slice = min(int(self.target_wh[0] * _RANDOM_TRIM), width - self.target_wh[0])
572
+ slicew_ratio = random.random()
573
+ left = int(slice * slicew_ratio)
574
+ right = width - int(slice * (1 - slicew_ratio))
575
+ sliceh_ratio = random.random()
576
+ top = int(slice * sliceh_ratio)
577
+ bottom = height - int(slice * (1 - sliceh_ratio))
578
+
579
+ image = image.crop((left, top, right, bottom))
580
+ else:
581
+ image_aspect = width / height
582
+ target_aspect = self.target_wh[0] / self.target_wh[1]
583
+ if image_aspect > target_aspect:
584
+ new_width = int(height * target_aspect)
585
+ jitter_amount = max(min(jitter_amount, int(abs(width - new_width) / 2)), 0)
586
+ left = jitter_amount
587
+ right = left + new_width
588
+ image = image.crop((left, 0, right, height))
589
+ else:
590
+ new_height = int(width / target_aspect)
591
+ jitter_amount = max(min(jitter_amount, int(abs(height - new_height) / 2)), 0)
592
+ top = jitter_amount
593
+ bottom = top + new_height
594
+ image = image.crop((0, top, width, bottom))
595
+ # LANCZOS resample
596
+ image = image.resize(self.target_wh, resample=Image.Resampling.LANCZOS)
597
+ # print the pixel count of the image
598
+ # print path to image file
599
+ # print(self.pathname)
600
+ # print(self.image.size[0] * self.image.size[1])
601
+ image = self.flip(image)
602
+ return image
603
+
604
+ def hydrate(self, crop=False, save=False, crop_jitter=20):
605
+ """
606
+ crop: hard center crop to 512x512
607
+ save: save the cropped image to disk, for manual inspection of resize/crop
608
+ crop_jitter: randomly shift the crop by N pixels when using multiple aspect ratios to improve training quality
609
+ """
610
+
611
+ if self.image is None:
612
+ chance = float(len(self.is_dupe)) / 10.0
613
+
614
+ flip_p = self.flip_p + chance if chance < 1.0 else 1.0
615
+ flip = random.uniform(0, 1) < flip_p
616
+
617
+ if len(self.is_dupe) > 0:
618
+ crop_jitter = crop_jitter + (len(self.is_dupe) * 10) if crop_jitter < 50 else 50
619
+
620
+ jitter_amount = random.randint(0, crop_jitter)
621
+
622
+ self.image = self.load_image(self.pathname, crop, jitter_amount, flip)
623
+
624
+ if self.model_variant == "inpainting" or self.load_mask:
625
+ if os.path.exists(self.mask_pathname) and self.load_mask:
626
+ self.mask = self.load_image(self.mask_pathname, crop, jitter_amount, flip)
627
+ else:
628
+ if self.variant_warning == False:
629
+ print(f" {bcolors.FAIL} ** Warning: No mask found for an image, using an empty mask but make sure you're training the right model variant.{bcolors.ENDC}")
630
+ self.variant_warning = True
631
+ self.mask = Image.new('RGB', self.image.size, color="white").convert("L")
632
+
633
+ if self.model_variant == "depth2img":
634
+ if os.path.exists(self.depth_pathname):
635
+ self.extra = self.load_image(self.depth_pathname, crop, jitter_amount, flip)
636
+ else:
637
+ if self.variant_warning == False:
638
+ print(f" {bcolors.FAIL} ** Warning: No depth found for an image, using an empty depth but make sure you're training the right model variant.{bcolors.ENDC}")
639
+ self.variant_warning = True
640
+ self.extra = Image.new('RGB', self.image.size, color="white").convert("L")
641
+ if type(self.image) is not np.ndarray:
642
+ if save:
643
+ base_name = os.path.basename(self.pathname)
644
+ if not os.path.exists("test/output"):
645
+ os.makedirs("test/output")
646
+ self.image.save(f"test/output/{base_name}")
647
+
648
+ self.image = np.array(self.image).astype(np.uint8)
649
+
650
+ self.image = (self.image / 127.5 - 1.0).astype(np.float32)
651
+ if self.mask is not None and type(self.mask) is not np.ndarray:
652
+ self.mask = np.array(self.mask).astype(np.uint8)
653
+
654
+ self.mask = (self.mask / 255.0).astype(np.float32)
655
+ if self.extra is not None and type(self.extra) is not np.ndarray:
656
+ self.extra = np.array(self.extra).astype(np.uint8)
657
+
658
+ self.extra = (self.extra / 255.0).astype(np.float32)
659
+
660
+ #print(self.image.shape)
661
+
662
+ return self
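+ # Illustrative note (not part of the original file): self_destruct() appends to is_dupe each time
+ # an item is consumed, so later hydrations of the same item raise the flip probability by 0.1 per
+ # previous use and widen the crop jitter, keeping duplicated bucket entries from producing
+ # pixel-identical training crops.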
663
+ class CachedLatentsDataset(Dataset):
664
+ #stores paths and loads latents on the fly
665
+ def __init__(self, cache_paths=(),batch_size=None,tokenizer=None,text_encoder=None,dtype=None,model_variant='base',shuffle_per_epoch=False,args=None):
666
+ self.cache_paths = cache_paths
667
+ self.tokenizer = tokenizer
668
+ self.args = args
669
+ self.text_encoder = text_encoder
670
+ #get text encoder device
671
+ text_encoder_device = next(self.text_encoder.parameters()).device
672
+ self.empty_batch = [self.tokenizer('',padding="do_not_pad",truncation=True,max_length=self.tokenizer.model_max_length,).input_ids for i in range(batch_size)]
673
+ #handle text encoder for empty tokens
674
+ if self.args.train_text_encoder != True:
675
+ self.empty_tokens = tokenizer.pad({"input_ids": self.empty_batch},padding="max_length",max_length=tokenizer.model_max_length,return_tensors="pt",).to(text_encoder_device).input_ids
676
+ self.empty_tokens.to(text_encoder_device, dtype=dtype)
677
+ self.empty_tokens = self.text_encoder(self.empty_tokens)[0]
678
+ else:
679
+ self.empty_tokens = tokenizer.pad({"input_ids": self.empty_batch},padding="max_length",max_length=tokenizer.model_max_length,return_tensors="pt",).input_ids
680
+ self.empty_tokens.to(text_encoder_device, dtype=dtype)
681
+
682
+ self.conditional_dropout = args.conditional_dropout
683
+ self.conditional_indexes = []
684
+ self.model_variant = model_variant
685
+ self.shuffle_per_epoch = shuffle_per_epoch
686
+ def __len__(self):
687
+ return len(self.cache_paths)
688
+ def __getitem__(self, index):
689
+ if index == 0:
690
+ if self.shuffle_per_epoch == True:
691
+ self.cache_paths = tuple(random.sample(self.cache_paths, len(self.cache_paths)))
692
+ if len(self.cache_paths) > 1:
693
+ possible_indexes_extension = None
694
+ possible_indexes = list(range(0,len(self.cache_paths)))
695
+ #conditional dropout is a percentage of images to drop from the total cache_paths
696
+ if self.conditional_dropout != None:
697
+ if len(self.conditional_indexes) == 0:
698
+ self.conditional_indexes = random.sample(possible_indexes, k=int(math.ceil(len(possible_indexes)*self.conditional_dropout)))
699
+ else:
700
+ #pick indexes from the remaining possible indexes
701
+ possible_indexes_extension = [i for i in possible_indexes if i not in self.conditional_indexes]
702
+ #duplicate all values in possible_indexes_extension
703
+ possible_indexes_extension = possible_indexes_extension + possible_indexes_extension
704
+ possible_indexes_extension = possible_indexes_extension + self.conditional_indexes
705
+ self.conditional_indexes = random.sample(possible_indexes_extension, k=int(math.ceil(len(possible_indexes)*self.conditional_dropout)))
706
+ #check for duplicates in conditional_indexes values
707
+ if len(self.conditional_indexes) != len(set(self.conditional_indexes)):
708
+ #remove duplicates
709
+ self.conditional_indexes_non_dupe = list(set(self.conditional_indexes))
710
+ #add a random value from possible_indexes_extension for each duplicate
711
+ for i in range(len(self.conditional_indexes) - len(self.conditional_indexes_non_dupe)):
712
+ while True:
713
+ random_value = random.choice(possible_indexes_extension)
714
+ if random_value not in self.conditional_indexes_non_dupe:
715
+ self.conditional_indexes_non_dupe.append(random_value)
716
+ break
717
+ self.conditional_indexes = self.conditional_indexes_non_dupe
718
+ self.cache = torch.load(self.cache_paths[index])
719
+ self.latents = self.cache.latents_cache[0]
720
+ self.tokens = self.cache.tokens_cache[0]
721
+ self.extra_cache = None
722
+ self.mask_cache = None
723
+ if self.cache.mask_cache is not None:
724
+ self.mask_cache = self.cache.mask_cache[0]
725
+ self.mask_mean_cache = None
726
+ if self.cache.mask_mean_cache is not None:
727
+ self.mask_mean_cache = self.cache.mask_mean_cache[0]
728
+ if index in self.conditional_indexes:
729
+ self.text_encoder = self.empty_tokens
730
+ else:
731
+ self.text_encoder = self.cache.text_encoder_cache[0]
732
+ if self.model_variant != 'base':
733
+ self.extra_cache = self.cache.extra_cache[0]
734
+ del self.cache
735
+ return self.latents, self.text_encoder, self.mask_cache, self.mask_mean_cache, self.extra_cache, self.tokens
736
+
737
+ def add_pt_cache(self, cache_path):
738
+ if len(self.cache_paths) == 0:
739
+ self.cache_paths = (cache_path,)
740
+ else:
741
+ self.cache_paths += (cache_path,)
742
+
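+ # Illustrative note (not part of the original file): conditional_dropout selects a fraction of the
+ # cached batches (re-sampled whenever index 0 is requested, i.e. once per epoch) and swaps their
+ # cached text embeddings for the empty-prompt embedding, training the unconditional branch used
+ # for classifier-free guidance.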
743
+ class LatentsDataset(Dataset):
744
+ def __init__(self, latents_cache=None, text_encoder_cache=None, mask_cache=None, mask_mean_cache=None, extra_cache=None,tokens_cache=None):
745
+ self.latents_cache = latents_cache
746
+ self.text_encoder_cache = text_encoder_cache
747
+ self.mask_cache = mask_cache
748
+ self.mask_mean_cache = mask_mean_cache
749
+ self.extra_cache = extra_cache
750
+ self.tokens_cache = tokens_cache
751
+ def add_latent(self, latent, text_encoder, cached_mask, cached_extra, tokens_cache):
752
+ self.latents_cache.append(latent)
753
+ self.text_encoder_cache.append(text_encoder)
754
+ self.mask_cache.append(cached_mask)
755
+ self.mask_mean_cache.append(None if cached_mask is None else cached_mask.mean())
756
+ self.extra_cache.append(cached_extra)
757
+ self.tokens_cache.append(tokens_cache)
758
+ def __len__(self):
759
+ return len(self.latents_cache)
760
+ def __getitem__(self, index):
761
+ return self.latents_cache[index], self.text_encoder_cache[index], self.mask_cache[index], self.mask_mean_cache[index], self.extra_cache[index], self.tokens_cache[index]
762
+
763
+ class DataLoaderMultiAspect():
764
+ """
765
+ Data loader for multi-aspect-ratio training and bucketing
766
+ data_root: root folder of training data
767
+ batch_size: number of images per batch
768
+ flip_p: probability of flipping image horizontally (i.e. 0-0.5)
769
+ """
770
+ def __init__(
771
+ self,
772
+ concept_list,
773
+ seed=555,
774
+ debug_level=0,
775
+ resolution=512,
776
+ batch_size=1,
777
+ flip_p=0.0,
778
+ use_image_names_as_captions=True,
779
+ add_class_images_to_dataset=False,
780
+ balance_datasets=False,
781
+ with_prior_loss=False,
782
+ use_text_files_as_captions=False,
783
+ aspect_mode='dynamic',
784
+ action_preference='add',
785
+ model_variant='base',
786
+ extra_module=None,
787
+ mask_prompts=None,
788
+ load_mask=False,
789
+ ):
790
+ self.resolution = resolution
791
+ self.debug_level = debug_level
792
+ self.flip_p = flip_p
793
+ self.use_image_names_as_captions = use_image_names_as_captions
794
+ self.balance_datasets = balance_datasets
795
+ self.with_prior_loss = with_prior_loss
796
+ self.add_class_images_to_dataset = add_class_images_to_dataset
797
+ self.use_text_files_as_captions = use_text_files_as_captions
798
+ self.aspect_mode = aspect_mode
799
+ self.action_preference = action_preference
800
+ self.seed = seed
801
+ self.model_variant = model_variant
802
+ self.extra_module = extra_module
803
+ self.load_mask = load_mask
804
+ prepared_train_data = []
805
+
806
+ self.aspects = get_aspect_buckets(resolution)
807
+ #print(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
808
+ #process sub directories flag
809
+
810
+ print(f" {bcolors.WARNING} Preloading images...{bcolors.ENDC}")
811
+
812
+ if balance_datasets:
813
+ print(f" {bcolors.WARNING} Balancing datasets...{bcolors.ENDC}")
814
+ #get the concept with the least number of images in instance_data_dir
815
+ for concept in concept_list:
816
+ count = 0
817
+ if 'use_sub_dirs' in concept:
818
+ if concept['use_sub_dirs'] == 1:
819
+ tot = 0
820
+ for root, dirs, files in os.walk(concept['instance_data_dir']):
821
+ tot += len(files)
822
+ count = tot
823
+ else:
824
+ count = len(os.listdir(concept['instance_data_dir']))
825
+ else:
826
+ count = len(os.listdir(concept['instance_data_dir']))
827
+ print(f"{concept['instance_data_dir']} has count of {count}")
828
+ concept['count'] = count
829
+
830
+ min_concept = min(concept_list, key=lambda x: x['count'])
831
+ #get the number of images in the concept with the least number of images
832
+ min_concept_num_images = min_concept['count']
833
+ print(" Min concept: ",min_concept['instance_data_dir']," with ",min_concept_num_images," images")
834
+
835
+ balance_concept_list = []
836
+ for concept in concept_list:
837
+ #if concept has a key do not balance it
838
+ if 'do_not_balance' in concept:
839
+ if concept['do_not_balance'] == True:
840
+ balance_concept_list.append(-1)
841
+ else:
842
+ balance_concept_list.append(min_concept_num_images)
843
+ else:
844
+ balance_concept_list.append(min_concept_num_images)
845
+ for concept in concept_list:
846
+ if 'use_sub_dirs' in concept:
847
+ if concept['use_sub_dirs'] == True:
848
+ use_sub_dirs = True
849
+ else:
850
+ use_sub_dirs = False
851
+ else:
852
+ use_sub_dirs = False
853
+ self.image_paths = []
854
+ #self.class_image_paths = []
855
+ min_concept_num_images = None
856
+ if balance_datasets:
857
+ min_concept_num_images = balance_concept_list[concept_list.index(concept)]
858
+ data_root = concept['instance_data_dir']
859
+ data_root_class = concept['class_data_dir']
860
+ concept_prompt = concept['instance_prompt']
861
+ concept_class_prompt = concept['class_prompt']
862
+ if 'flip_p' in concept.keys():
863
+ flip_p = concept['flip_p']
864
+ if flip_p == '':
865
+ flip_p = 0.0
866
+ else:
867
+ flip_p = float(flip_p)
868
+ self.__recurse_data_root(self=self, recurse_root=data_root,use_sub_dirs=use_sub_dirs)
869
+ random.Random(self.seed).shuffle(self.image_paths)
870
+ if self.model_variant == 'depth2img':
871
+ print(f" {bcolors.WARNING} ** Loading Depth2Img Pipeline To Process Dataset{bcolors.ENDC}")
872
+ self.vae_scale_factor = self.extra_module.depth_images(self.image_paths)
873
+ prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_prompt,use_text_files_as_captions=self.use_text_files_as_captions)[0:min_concept_num_images]) # ImageTrainItem[]
874
+ if add_class_images_to_dataset:
875
+ self.image_paths = []
876
+ self.__recurse_data_root(self=self, recurse_root=data_root_class,use_sub_dirs=use_sub_dirs)
877
+ random.Random(self.seed).shuffle(self.image_paths)
878
+ use_image_names_as_captions = False
879
+ prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_class_prompt,use_text_files_as_captions=self.use_text_files_as_captions)) # ImageTrainItem[]
880
+
881
+ self.image_caption_pairs = self.__bucketize_images(prepared_train_data, batch_size=batch_size, debug_level=debug_level,aspect_mode=self.aspect_mode,action_preference=self.action_preference)
882
+ if self.with_prior_loss and add_class_images_to_dataset == False:
883
+ self.class_image_caption_pairs = []
884
+ for concept in concept_list:
885
+ self.class_images_path = []
886
+ data_root_class = concept['class_data_dir']
887
+ concept_class_prompt = concept['class_prompt']
888
+ self.__recurse_data_root(self=self, recurse_root=data_root_class,use_sub_dirs=use_sub_dirs,class_images=True)
889
+ random.Random(seed).shuffle(self.image_paths)
890
+ if self.model_variant == 'depth2img':
891
+ print(f" {bcolors.WARNING} ** Depth2Img To Process Class Dataset{bcolors.ENDC}")
892
+ self.vae_scale_factor = self.extra_module.depth_images(self.image_paths)
893
+ use_image_names_as_captions = False
894
+ self.class_image_caption_pairs.extend(self.__prescan_images(debug_level, self.class_images_path, flip_p,use_image_names_as_captions,concept_class_prompt,use_text_files_as_captions=self.use_text_files_as_captions))
895
+ self.class_image_caption_pairs = self.__bucketize_images(self.class_image_caption_pairs, batch_size=batch_size, debug_level=debug_level,aspect_mode=self.aspect_mode,action_preference=self.action_preference)
896
+ if mask_prompts is not None:
897
+ print(f" {bcolors.WARNING} Checking and generating missing masks...{bcolors.ENDC}")
898
+ clip_seg = ClipSeg()
899
+ clip_seg.mask_images(self.image_paths, mask_prompts)
900
+ del clip_seg
901
+ if debug_level > 0: print(f" * DLMA Example: {self.image_caption_pairs[0]} images")
902
+ #print the length of image_caption_pairs
903
+ print(f" {bcolors.WARNING} Number of image-caption pairs: {len(self.image_caption_pairs)}{bcolors.ENDC}")
904
+ if len(self.image_caption_pairs) == 0:
905
+ raise Exception("All the buckets are empty. Please check your data or reduce the batch size.")
906
+ def get_all_images(self):
907
+ if self.with_prior_loss == False:
908
+ return self.image_caption_pairs
909
+ else:
910
+ return self.image_caption_pairs, self.class_image_caption_pairs
911
+ def __prescan_images(self,debug_level: int, image_paths: list, flip_p=0.0,use_image_names_as_captions=True,concept=None,use_text_files_as_captions=False):
912
+ """
913
+ Create ImageTrainItem objects with metadata for hydration later
914
+ """
915
+ decorated_image_train_items = []
916
+
917
+ for pathname in image_paths:
918
+ identifier = concept
919
+ if use_image_names_as_captions:
920
+ caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
921
+ identifier = caption_from_filename
922
+ if use_text_files_as_captions:
923
+ txt_file_path = os.path.splitext(pathname)[0] + ".txt"
924
+
925
+ if os.path.exists(txt_file_path):
926
+ try:
927
+ with open(txt_file_path, 'r',encoding='utf-8',errors='ignore') as f:
928
+ identifier = f.readline().rstrip()
929
+ f.close()
930
+ if len(identifier) < 1:
931
+ raise ValueError(f" *** Could not find valid text in: {txt_file_path}")
932
+
933
+ except Exception as e:
934
+ print(f" {bcolors.FAIL} *** Error reading {txt_file_path} to get caption, falling back to filename{bcolors.ENDC}")
935
+ print(e)
936
+ identifier = caption_from_filename
937
+ pass
938
+ #print("identifier: ",identifier)
939
+ image = Image.open(pathname)
940
+ width, height = image.size
941
+ image_aspect = width / height
942
+
943
+ target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
944
+
945
+ image_train_item = ImageTrainItem(image=None, mask=None, extra=None, caption=identifier, target_wh=target_wh, pathname=pathname, flip_p=flip_p,model_variant=self.model_variant, load_mask=self.load_mask)
946
+
947
+ decorated_image_train_items.append(image_train_item)
948
+ return decorated_image_train_items
949
+
950
+ @staticmethod
951
+ def __bucketize_images(prepared_train_data: list, batch_size=1, debug_level=0,aspect_mode='dynamic',action_preference='add'):
952
+ """
953
+ Put images into buckets based on aspect ratio with batch_size*n images per bucket; depending on aspect_mode, the remainder is either dropped or filled by duplicating images
954
+ """
955
+
956
+ # TODO: this is not terribly efficient but at least linear time
957
+ buckets = {}
958
+ for image_caption_pair in prepared_train_data:
959
+ target_wh = image_caption_pair.target_wh
960
+
961
+ if (target_wh[0],target_wh[1]) not in buckets:
962
+ buckets[(target_wh[0],target_wh[1])] = []
963
+ buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
964
+ print(f" ** Number of buckets: {len(buckets)}")
965
+ for bucket in buckets:
966
+ bucket_len = len(buckets[bucket])
967
+ #real_len = len(buckets[bucket])+1
968
+ #print(real_len)
969
+ truncate_amount = bucket_len % batch_size
970
+ add_amount = batch_size - bucket_len % batch_size
971
+ action = None
972
+ #print(f" ** Bucket {bucket} has {bucket_len} images")
973
+ if aspect_mode == 'dynamic':
974
+ if batch_size == bucket_len:
975
+ action = None
976
+ elif add_amount < truncate_amount and add_amount != 0 and add_amount != batch_size or truncate_amount == 0:
977
+ action = 'add'
978
+ #print(f'should add {add_amount}')
979
+ elif truncate_amount < add_amount and truncate_amount != 0 and truncate_amount != batch_size and batch_size < bucket_len:
980
+ #print(f'should truncate {truncate_amount}')
981
+ action = 'truncate'
982
+ #truncate the bucket
983
+ elif truncate_amount == add_amount:
984
+ if action_preference == 'add':
985
+ action = 'add'
986
+ elif action_preference == 'truncate':
987
+ action = 'truncate'
988
+ elif batch_size > bucket_len:
989
+ action = 'add'
990
+
991
+ elif aspect_mode == 'add':
992
+ action = 'add'
993
+ elif aspect_mode == 'truncate':
994
+ action = 'truncate'
995
+ if action == None:
996
+ action = None
997
+ #print('no need to add or truncate')
998
+ if action == None:
999
+ #print('test')
1000
+ current_bucket_size = bucket_len
1001
+ print(f" ** Bucket {bucket} found {bucket_len}, nice!")
1002
+ elif action == 'add':
1003
+ #copy the bucket
1004
+ shuffleBucket = random.sample(buckets[bucket], bucket_len)
1005
+ #add the images to the bucket
1006
+ current_bucket_size = bucket_len
1007
+ truncate_count = (bucket_len) % batch_size
1008
+ #how many images to add to the bucket to fill the batch
1009
+ addAmount = batch_size - truncate_count
1010
+ if addAmount != batch_size:
1011
+ added=0
1012
+ while added != addAmount:
1013
+ randomIndex = random.randint(0,len(shuffleBucket)-1)
1014
+ #print(str(randomIndex))
1015
+ buckets[bucket].append(shuffleBucket[randomIndex])
1016
+ added+=1
1017
+ print(f" ** Bucket {bucket} found {bucket_len} images, will {bcolors.OKCYAN}duplicate {added} images{bcolors.ENDC} due to batch size {bcolors.WARNING}{batch_size}{bcolors.ENDC}")
1018
+ else:
1019
+ print(f" ** Bucket {bucket} found {bucket_len}, {bcolors.OKGREEN}nice!{bcolors.ENDC}")
1020
+ elif action == 'truncate':
1021
+ truncate_count = (bucket_len) % batch_size
1022
+ current_bucket_size = bucket_len
1023
+ buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
1024
+ print(f" ** Bucket {bucket} found {bucket_len} images, will {bcolors.FAIL}drop {truncate_count} images{bcolors.ENDC} due to batch size {bcolors.WARNING}{batch_size}{bcolors.ENDC}")
1025
+
1026
+
1027
+ # flatten the buckets
1028
+ image_caption_pairs = []
1029
+ for bucket in buckets:
1030
+ image_caption_pairs.extend(buckets[bucket])
1031
+
1032
+ return image_caption_pairs
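+ # Illustrative note (not part of the original file): with batch_size=4 and a bucket of 10 images,
+ # truncate_amount = 10 % 4 = 2 and add_amount = 4 - 2 = 2, so in 'dynamic' mode the tie is broken
+ # by action_preference; with 11 images, add_amount = 1 < truncate_amount = 3, so one image is
+ # duplicated to round the bucket up to a multiple of the batch size.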
1033
+
1034
+ @staticmethod
1035
+ def __recurse_data_root(self, recurse_root,use_sub_dirs=True,class_images=False):
1036
+ progress_bar = tqdm(os.listdir(recurse_root), desc=f" {bcolors.WARNING} ** Processing {recurse_root}{bcolors.ENDC}")
1037
+ for f in os.listdir(recurse_root):
1038
+ current = os.path.join(recurse_root, f)
1039
+ if os.path.isfile(current):
1040
+ ext = os.path.splitext(f)[1].lower()
1041
+ if '-depth' in f or '-masklabel' in f:
1042
+ progress_bar.update(1)
1043
+ continue
1044
+ if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']:
1045
+ #try to open the file to make sure it's a valid image
1046
+ try:
1047
+ img = Image.open(current)
1048
+ except:
1049
+ print(f" ** Skipping {current} because it failed to open, please check the file")
1050
+ progress_bar.update(1)
1051
+ continue
1052
+ del img
1053
+ if class_images == False:
1054
+ self.image_paths.append(current)
1055
+ else:
1056
+ self.class_images_path.append(current)
1057
+ progress_bar.update(1)
1058
+ if use_sub_dirs:
1059
+ sub_dirs = []
1060
+
1061
+ for d in os.listdir(recurse_root):
1062
+ current = os.path.join(recurse_root, d)
1063
+ if os.path.isdir(current):
1064
+ sub_dirs.append(current)
1065
+
1066
+ for dir in sub_dirs:
1067
+ self.__recurse_data_root(self=self, recurse_root=dir)
1068
+
1069
+ class NormalDataset(Dataset):
1070
+ """
1071
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
1072
+ It pre-processes the images and tokenizes the prompts.
1073
+ """
1074
+
1075
+ def __init__(
1076
+ self,
1077
+ concepts_list,
1078
+ tokenizer,
1079
+ with_prior_preservation=True,
1080
+ size=512,
1081
+ center_crop=False,
1082
+ num_class_images=None,
1083
+ use_image_names_as_captions=False,
1084
+ shuffle_captions=False,
1085
+ repeats=1,
1086
+ use_text_files_as_captions=False,
1087
+ seed=555,
1088
+ model_variant='base',
1089
+ extra_module=None,
1090
+ mask_prompts=None,
1091
+ load_mask=None,
1092
+ ):
1093
+ self.use_image_names_as_captions = use_image_names_as_captions
1094
+ self.shuffle_captions = shuffle_captions
1095
+ self.size = size
1096
+ self.center_crop = center_crop
1097
+ self.tokenizer = tokenizer
1098
+ self.with_prior_preservation = with_prior_preservation
1099
+ self.use_text_files_as_captions = use_text_files_as_captions
1100
+ self.image_paths = []
1101
+ self.class_images_path = []
1102
+ self.seed = seed
1103
+ self.model_variant = model_variant
1104
+ self.variant_warning = False
1105
+ self.vae_scale_factor = None
1106
+ self.load_mask = load_mask
1107
+ for concept in concepts_list:
1108
+ if 'use_sub_dirs' in concept:
1109
+ if concept['use_sub_dirs'] == True:
1110
+ use_sub_dirs = True
1111
+ else:
1112
+ use_sub_dirs = False
1113
+ else:
1114
+ use_sub_dirs = False
1115
+
1116
+ for i in range(repeats):
1117
+ self.__recurse_data_root(self, concept,use_sub_dirs=use_sub_dirs)
1118
+
1119
+ if with_prior_preservation:
1120
+ for i in range(repeats):
1121
+ self.__recurse_data_root(self, concept,use_sub_dirs=False,class_images=True)
1122
+ if mask_prompts is not None:
1123
+ print(f" {bcolors.WARNING} Checking and generating missing masks{bcolors.ENDC}")
1124
+ clip_seg = ClipSeg()
1125
+ clip_seg.mask_images(self.image_paths, mask_prompts)
1126
+ del clip_seg
1127
+
1128
+ random.Random(seed).shuffle(self.image_paths)
1129
+ self.num_instance_images = len(self.image_paths)
1130
+ self._length = self.num_instance_images
1131
+ self.num_class_images = len(self.class_images_path)
1132
+ self._length = max(self.num_class_images, self.num_instance_images)
1133
+ if self.model_variant == 'depth2img':
1134
+ print(f" {bcolors.WARNING} ** Loading Depth2Img Pipeline To Process Dataset{bcolors.ENDC}")
1135
+ self.vae_scale_factor = extra_module.depth_images(self.image_paths)
1136
+ if self.with_prior_preservation:
1137
+ print(f" {bcolors.WARNING} ** Loading Depth2Img Class Processing{bcolors.ENDC}")
1138
+ extra_module.depth_images(self.class_images_path)
1139
+ print(f" {bcolors.WARNING} ** Dataset length: {self._length}, {int(self.num_instance_images / repeats)} images using {repeats} repeats{bcolors.ENDC}")
1140
+
1141
+ self.image_transforms = transforms.Compose(
1142
+ [
1143
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
1144
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
1145
+ transforms.ToTensor(),
1146
+ transforms.Normalize([0.5], [0.5]),
1147
+ ]
1148
+
1149
+ )
1150
+ self.mask_transforms = transforms.Compose(
1151
+ [
1152
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
1153
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
1154
+ transforms.ToTensor(),
1155
+ ])
1156
+
1157
+ self.depth_image_transforms = transforms.Compose(
1158
+ [
1159
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
1160
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
1161
+ transforms.ToTensor(),
1162
+ ]
1163
+ )
1164
+
1165
+ @staticmethod
1166
+ def __recurse_data_root(self, recurse_root,use_sub_dirs=True,class_images=False):
1167
+ #if recurse root is a dict
1168
+ if isinstance(recurse_root, dict):
1169
+ if class_images == True:
1170
+ #print(f" {bcolors.WARNING} ** Processing class images: {recurse_root['class_data_dir']}{bcolors.ENDC}")
1171
+ concept_token = recurse_root['class_prompt']
1172
+ data = recurse_root['class_data_dir']
1173
+ else:
1174
+ #print(f" {bcolors.WARNING} ** Processing instance images: {recurse_root['instance_data_dir']}{bcolors.ENDC}")
1175
+ concept_token = recurse_root['instance_prompt']
1176
+ data = recurse_root['instance_data_dir']
1177
+
1178
+
1179
+ else:
1180
+ concept_token = None
+ data = recurse_root
1181
+ #progress bar
1182
+ progress_bar = tqdm(os.listdir(data), desc=f" {bcolors.WARNING} ** Processing {data}{bcolors.ENDC}")
1183
+ for f in os.listdir(data):
1184
+ current = os.path.join(data, f)
1185
+ if os.path.isfile(current):
1186
+ if '-depth' in f or '-masklabel' in f:
1187
+ progress_bar.update(1)
+ continue
1188
+ ext = os.path.splitext(f)[1].lower()
1189
+ if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp']:
1190
+ try:
1191
+ img = Image.open(current)
1192
+ except Exception:
1193
+ print(f" ** Skipping {current} because it failed to open, please check the file")
1194
+ progress_bar.update(1)
1195
+ continue
1196
+ del img
1197
+ if class_images == False:
1198
+ self.image_paths.append([current,concept_token])
1199
+ else:
1200
+ self.class_images_path.append([current,concept_token])
1201
+ progress_bar.update(1)
1202
+ if use_sub_dirs:
1203
+ sub_dirs = []
1204
+
1205
+ for d in os.listdir(data):
1206
+ current = os.path.join(data, d)
1207
+ if os.path.isdir(current):
1208
+ sub_dirs.append(current)
1209
+
1210
+ for dir in sub_dirs:
1211
+ if class_images == False:
1212
+ self.__recurse_data_root(self=self, recurse_root={'instance_data_dir' : dir, 'instance_prompt' : concept_token})
1213
+ else:
1214
+ self.__recurse_data_root(self=self, recurse_root={'class_data_dir' : dir, 'class_prompt' : concept_token})
1215
+
1216
+ def __len__(self):
1217
+ return self._length
1218
+
1219
+ def __getitem__(self, index):
1220
+ example = {}
1221
+ instance_path, instance_prompt = self.image_paths[index % self.num_instance_images]
1222
+ og_prompt = instance_prompt
1223
+ instance_image = Image.open(instance_path)
1224
+ if self.model_variant == "inpainting" or self.load_mask:
1225
+
1226
+ mask_pathname = os.path.splitext(instance_path)[0] + "-masklabel.png"
1227
+ if os.path.exists(mask_pathname) and self.load_mask:
1228
+ mask = Image.open(mask_pathname).convert("L")
1229
+ else:
1230
+ if self.variant_warning == False:
1231
+ print(f" {bcolors.FAIL} ** Warning: No mask found for an image, using an empty mask but make sure you're training the right model variant.{bcolors.ENDC}")
1232
+ self.variant_warning = True
1233
+ size = instance_image.size
1234
+ mask = Image.new('RGB', size, color="white").convert("L")
1235
+ example["mask"] = self.mask_transforms(mask)
1236
+ if self.model_variant == "depth2img":
1237
+ depth_pathname = os.path.splitext(instance_path)[0] + "-depth.png"
1238
+ if os.path.exists(depth_pathname):
1239
+ depth_image = Image.open(depth_pathname).convert("L")
1240
+ else:
1241
+ if self.variant_warning == False:
1242
+ print(f" {bcolors.FAIL} ** Warning: No depth image found for an image, using an empty depth image but make sure you're training the right model variant.{bcolors.ENDC}")
1243
+ self.variant_warning = True
1244
+ size = instance_image.size
1245
+ depth_image = Image.new('RGB', size, color="white").convert("L")
1246
+ example["instance_depth_images"] = self.depth_image_transforms(depth_image)
1247
+
1248
+ if self.use_image_names_as_captions == True:
1249
+ instance_prompt = str(instance_path).split(os.sep)[-1].split('.')[0].split('_')[0]
1250
+ #else if there's a txt file with the same name as the image, read the caption from there
1251
+ if self.use_text_files_as_captions == True:
1252
+ #if there's a file with the same name as the image, but with a .txt extension, read the caption from there
1253
+ #get the last . in the file name
1254
+ last_dot = str(instance_path).rfind('.')
1255
+ #get the path up to the last dot
1256
+ txt_path = str(instance_path)[:last_dot] + '.txt'
1257
+
1258
+ #if txt_path exists, read the caption from there
1259
+ if os.path.exists(txt_path):
1260
+ with open(txt_path, encoding='utf-8') as f:
1261
+ instance_prompt = f.readline().rstrip()
1262
+ f.close()
1263
+
1264
+ if self.shuffle_captions:
1265
+ caption_parts = instance_prompt.split(",")
1266
+ random.shuffle(caption_parts)
1267
+ instance_prompt = ",".join(caption_parts)
1268
+
1269
+ #print('identifier: ' + instance_prompt)
1270
+ instance_image = instance_image.convert("RGB")
1271
+ example["instance_images"] = self.image_transforms(instance_image)
1272
+ example["instance_prompt_ids"] = self.tokenizer(
1273
+ instance_prompt,
1274
+ padding="do_not_pad",
1275
+ truncation=True,
1276
+ max_length=self.tokenizer.model_max_length,
1277
+ ).input_ids
1278
+ if self.with_prior_preservation:
1279
+ class_path, class_prompt = self.class_images_path[index % self.num_class_images]
1280
+ class_image = Image.open(class_path)
1281
+ if not class_image.mode == "RGB":
1282
+ class_image = class_image.convert("RGB")
1283
+
1284
+ if self.model_variant == "inpainting":
1285
+ mask_pathname = os.path.splitext(class_path)[0] + "-masklabel.png"
1286
+ if os.path.exists(mask_pathname):
1287
+ mask = Image.open(mask_pathname).convert("L")
1288
+ else:
1289
+ if self.variant_warning == False:
1290
+ print(f" {bcolors.FAIL} ** Warning: No mask found for an image, using an empty mask but make sure you're training the right model variant.{bcolors.ENDC}")
1291
+ self.variant_warning = True
1292
+ size = instance_image.size
1293
+ mask = Image.new('RGB', size, color="white").convert("L")
1294
+ example["class_mask"] = self.mask_transforms(mask)
1295
+ if self.model_variant == "depth2img":
1296
+ depth_pathname = os.path.splitext(class_path)[0] + "-depth.png"
1297
+ if os.path.exists(depth_pathname):
1298
+ depth_image = Image.open(depth_pathname)
1299
+ else:
1300
+ if self.variant_warning == False:
1301
+ print(f" {bcolors.FAIL} ** Warning: No depth image found for an image, using an empty depth image but make sure you're training the right model variant.{bcolors.ENDC}")
1302
+ self.variant_warning = True
1303
+ size = instance_image.size
1304
+ depth_image = Image.new('RGB', size, color="white").convert("L")
1305
+ example["class_depth_images"] = self.depth_image_transforms(depth_image)
1306
+ example["class_images"] = self.image_transforms(class_image)
1307
+ example["class_prompt_ids"] = self.tokenizer(
1308
+ class_prompt,
1309
+ padding="do_not_pad",
1310
+ truncation=True,
1311
+ max_length=self.tokenizer.model_max_length,
1312
+ ).input_ids
1313
+
1314
+ return example
1315
+
1316
+
1317
+ class PromptDataset(Dataset):
1318
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
1319
+
1320
+ def __init__(self, prompt, num_samples):
1321
+ self.prompt = prompt
1322
+ self.num_samples = num_samples
1323
+
1324
+ def __len__(self):
1325
+ return self.num_samples
1326
+
1327
+ def __getitem__(self, index):
1328
+ example = {}
1329
+ example["prompt"] = self.prompt
1330
+ example["index"] = index
1331
+ return example
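A minimal usage sketch for the two datasets above (not part of the commit): the tokenizer checkpoint, folder paths, and prompts are placeholder assumptions, and the real trainer pads `instance_prompt_ids` with a custom collate function; only the constructor arguments and the concept keys (`instance_prompt`, `instance_data_dir`, `class_prompt`, `class_data_dir`, `use_sub_dirs`) come from the code above.

```py
from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")

concepts_list = [{
    "instance_prompt": "photo of sks person",     # placeholder prompt
    "instance_data_dir": "data/instance_images",  # placeholder path
    "class_prompt": "photo of a person",
    "class_data_dir": "data/class_images",
    "use_sub_dirs": False,
}]

train_dataset = NormalDataset(
    concepts_list,
    tokenizer,
    with_prior_preservation=False,
    size=512,
    shuffle_captions=True,
)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# PromptDataset simply repeats one prompt, e.g. when pre-generating class images:
prompt_dataset = PromptDataset("photo of a person", num_samples=200)
```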
StableTuner_RunPod_Fix/discriminator.py ADDED
@@ -0,0 +1,764 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import einops, einops.layers.torch
6
+ import diffusers
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from typing import Tuple, Optional
9
+
10
+ import inspect
11
+ import os
12
+ from functools import partial
13
+ from typing import Callable, List, Optional, Tuple, Union
14
+
15
+ import torch
16
+ from torch import Tensor, device
17
+
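# Note: the ModelMixin copy below mirrors diffusers.models.modeling_utils.ModelMixin and references
# helper symbols (logger, WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, DIFFUSERS_CACHE,
# HF_HUB_OFFLINE, _add_variant, _get_model_file, load_state_dict, _load_state_dict_into_model,
# _LOW_CPU_MEM_USAGE_DEFAULT, is_safetensors_available, is_accelerate_available, is_torch_version,
# set_module_tensor_to_device, get_parameter_device, get_parameter_dtype, safetensors, accelerate)
# that are not defined or imported in this file; they are presumably expected to come from
# diffusers' internal utilities when the save/load paths are actually exercised.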
18
+
19
+
20
+ class ModelMixin(torch.nn.Module):
21
+ r"""
22
+ Base class for all models.
23
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
24
+ and saving models.
25
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
26
+ [`~models.ModelMixin.save_pretrained`].
27
+ """
28
+ config_name = "new"
29
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
30
+ _supports_gradient_checkpointing = False
31
+
32
+ def __init__(self):
33
+ super().__init__()
34
+
35
+ @property
36
+ def is_gradient_checkpointing(self) -> bool:
37
+ """
38
+ Whether gradient checkpointing is activated for this model or not.
39
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
40
+ activations".
41
+ """
42
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
43
+
44
+ def enable_gradient_checkpointing(self):
45
+ """
46
+ Activates gradient checkpointing for the current model.
47
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
48
+ activations".
49
+ """
50
+ if not self._supports_gradient_checkpointing:
51
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
52
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
53
+
54
+ def disable_gradient_checkpointing(self):
55
+ """
56
+ Deactivates gradient checkpointing for the current model.
57
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
58
+ activations".
59
+ """
60
+ if self._supports_gradient_checkpointing:
61
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
62
+
63
+ def set_use_memory_efficient_attention_xformers(
64
+ self, valid: bool, attention_op: Optional[Callable] = None
65
+ ) -> None:
66
+ # Recursively walk through all the children.
67
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
68
+ # gets the message
69
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
70
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
71
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
72
+
73
+ for child in module.children():
74
+ fn_recursive_set_mem_eff(child)
75
+
76
+ for module in self.children():
77
+ if isinstance(module, torch.nn.Module):
78
+ fn_recursive_set_mem_eff(module)
79
+
80
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
81
+ r"""
82
+ Enable memory efficient attention as implemented in xformers.
83
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
84
+ time. Speed up at training time is not guaranteed.
85
+ Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
86
+ is used.
87
+ Parameters:
88
+ attention_op (`Callable`, *optional*):
89
+ Override the default `None` operator for use as `op` argument to the
90
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
91
+ function of xFormers.
92
+ Examples:
93
+ ```py
94
+ >>> import torch
95
+ >>> from diffusers import UNet2DConditionModel
96
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
97
+ >>> model = UNet2DConditionModel.from_pretrained(
98
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
99
+ ... )
100
+ >>> model = model.to("cuda")
101
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
102
+ ```
103
+ """
104
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
105
+
106
+ def disable_xformers_memory_efficient_attention(self):
107
+ r"""
108
+ Disable memory efficient attention as implemented in xformers.
109
+ """
110
+ self.set_use_memory_efficient_attention_xformers(False)
111
+
112
+ def save_pretrained(
113
+ self,
114
+ save_directory: Union[str, os.PathLike],
115
+ is_main_process: bool = True,
116
+ save_function: Callable = None,
117
+ safe_serialization: bool = False,
118
+ variant: Optional[str] = None,
119
+ ):
120
+ """
121
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
122
+ `[`~models.ModelMixin.from_pretrained`]` class method.
123
+ Arguments:
124
+ save_directory (`str` or `os.PathLike`):
125
+ Directory to which to save. Will be created if it doesn't exist.
126
+ is_main_process (`bool`, *optional*, defaults to `True`):
127
+ Whether the process calling this is the main process or not. Useful when in distributed training like
128
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
129
+ the main process to avoid race conditions.
130
+ save_function (`Callable`):
131
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
132
+ need to replace `torch.save` by another method. Can be configured with the environment variable
133
+ `DIFFUSERS_SAVE_MODE`.
134
+ safe_serialization (`bool`, *optional*, defaults to `False`):
135
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
136
+ variant (`str`, *optional*):
137
+ If specified, weights are saved in the format pytorch_model.<variant>.bin.
138
+ """
139
+ if safe_serialization and not is_safetensors_available():
140
+ raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
141
+
142
+ if os.path.isfile(save_directory):
143
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
144
+ return
145
+
146
+ os.makedirs(save_directory, exist_ok=True)
147
+
148
+ model_to_save = self
149
+
150
+ # Attach architecture to the config
151
+ # Save the config
152
+ if is_main_process:
153
+ model_to_save.save_config(save_directory)
154
+
155
+ # Save the model
156
+ state_dict = model_to_save.state_dict()
157
+
158
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
159
+ weights_name = _add_variant(weights_name, variant)
160
+
161
+ # Save the model
162
+ if safe_serialization:
163
+ safetensors.torch.save_file(
164
+ state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
165
+ )
166
+ else:
167
+ torch.save(state_dict, os.path.join(save_directory, weights_name))
168
+
169
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
170
+
171
+ @classmethod
172
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
173
+ r"""
174
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
175
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
176
+ the model, you should first set it back in training mode with `model.train()`.
177
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
178
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
179
+ task.
180
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
181
+ weights are discarded.
182
+ Parameters:
183
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
184
+ Can be either:
185
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
186
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
187
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
188
+ `./my_model_directory/`.
189
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
190
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
191
+ standard cache should not be used.
192
+ torch_dtype (`str` or `torch.dtype`, *optional*):
193
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
194
+ will be automatically derived from the model's weights.
195
+ force_download (`bool`, *optional*, defaults to `False`):
196
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
197
+ cached versions if they exist.
198
+ resume_download (`bool`, *optional*, defaults to `False`):
199
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
200
+ file exists.
201
+ proxies (`Dict[str, str]`, *optional*):
202
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
203
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
204
+ output_loading_info(`bool`, *optional*, defaults to `False`):
205
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
206
+ local_files_only(`bool`, *optional*, defaults to `False`):
207
+ Whether or not to only look at local files (i.e., do not try to download the model).
208
+ use_auth_token (`str` or *bool*, *optional*):
209
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
210
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
211
+ revision (`str`, *optional*, defaults to `"main"`):
212
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
213
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
214
+ identifier allowed by git.
215
+ from_flax (`bool`, *optional*, defaults to `False`):
216
+ Load the model weights from a Flax checkpoint save file.
217
+ subfolder (`str`, *optional*, defaults to `""`):
218
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
219
+ huggingface.co or downloaded locally), you can specify the folder name here.
220
+ mirror (`str`, *optional*):
221
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
222
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
223
+ Please refer to the mirror site for more information.
224
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
225
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
226
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
227
+ same device.
228
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
229
+ more information about each option see [designing a device
230
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
231
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
232
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
233
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
234
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
235
+ setting this argument to `True` will raise an error.
236
+ variant (`str`, *optional*):
237
+ If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
238
+ ignored when using `from_flax`.
239
+ use_safetensors (`bool`, *optional* ):
240
+ If set to `True`, the pipeline will forcibly load the models from `safetensors` weights. If set to
241
+ `None` (the default). The pipeline will load using `safetensors` if safetensors weights are available
242
+ *and* if `safetensors` is installed. If set to `False`, the pipeline will *not* use `safetensors`.
243
+ <Tip>
244
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
245
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
246
+ </Tip>
247
+ <Tip>
248
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
249
+ this method in a firewalled environment.
250
+ </Tip>
251
+ """
252
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
253
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
254
+ force_download = kwargs.pop("force_download", False)
255
+ from_flax = kwargs.pop("from_flax", False)
256
+ resume_download = kwargs.pop("resume_download", False)
257
+ proxies = kwargs.pop("proxies", None)
258
+ output_loading_info = kwargs.pop("output_loading_info", False)
259
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
260
+ use_auth_token = kwargs.pop("use_auth_token", None)
261
+ revision = kwargs.pop("revision", None)
262
+ torch_dtype = kwargs.pop("torch_dtype", None)
263
+ subfolder = kwargs.pop("subfolder", None)
264
+ device_map = kwargs.pop("device_map", None)
265
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
266
+ variant = kwargs.pop("variant", None)
267
+ use_safetensors = kwargs.pop("use_safetensors", None)
268
+
269
+ if use_safetensors and not is_safetensors_available():
270
+ raise ValueError(
271
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
272
+ )
273
+
274
+ allow_pickle = False
275
+ if use_safetensors is None:
276
+ use_safetensors = is_safetensors_available()
277
+ allow_pickle = True
278
+
279
+ if low_cpu_mem_usage and not is_accelerate_available():
280
+ low_cpu_mem_usage = False
281
+ logger.warning(
282
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
283
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
284
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
285
+ " install accelerate\n```\n."
286
+ )
287
+
288
+ if device_map is not None and not is_accelerate_available():
289
+ raise NotImplementedError(
290
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
291
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
292
+ )
293
+
294
+ # Check if we can handle device_map and dispatching the weights
295
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
296
+ raise NotImplementedError(
297
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
298
+ " `device_map=None`."
299
+ )
300
+
301
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
302
+ raise NotImplementedError(
303
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
304
+ " `low_cpu_mem_usage=False`."
305
+ )
306
+
307
+ if low_cpu_mem_usage is False and device_map is not None:
308
+ raise ValueError(
309
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
310
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
311
+ )
312
+
313
+ # Load config if we don't provide a configuration
314
+ config_path = pretrained_model_name_or_path
315
+
316
+ user_agent = {
317
+ "diffusers": __version__,
318
+ "file_type": "model",
319
+ "framework": "pytorch",
320
+ }
321
+
322
+ # load config
323
+ config, unused_kwargs, commit_hash = cls.load_config(
324
+ config_path,
325
+ cache_dir=cache_dir,
326
+ return_unused_kwargs=True,
327
+ return_commit_hash=True,
328
+ force_download=force_download,
329
+ resume_download=resume_download,
330
+ proxies=proxies,
331
+ local_files_only=local_files_only,
332
+ use_auth_token=use_auth_token,
333
+ revision=revision,
334
+ subfolder=subfolder,
335
+ device_map=device_map,
336
+ user_agent=user_agent,
337
+ **kwargs,
338
+ )
339
+
340
+ # load model
341
+ model_file = None
342
+ if from_flax:
343
+ model_file = _get_model_file(
344
+ pretrained_model_name_or_path,
345
+ weights_name=FLAX_WEIGHTS_NAME,
346
+ cache_dir=cache_dir,
347
+ force_download=force_download,
348
+ resume_download=resume_download,
349
+ proxies=proxies,
350
+ local_files_only=local_files_only,
351
+ use_auth_token=use_auth_token,
352
+ revision=revision,
353
+ subfolder=subfolder,
354
+ user_agent=user_agent,
355
+ commit_hash=commit_hash,
356
+ )
357
+ model = cls.from_config(config, **unused_kwargs)
358
+
359
+ # Convert the weights
360
+ from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
361
+
362
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
363
+ else:
364
+ if use_safetensors:
365
+ try:
366
+ model_file = _get_model_file(
367
+ pretrained_model_name_or_path,
368
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
369
+ cache_dir=cache_dir,
370
+ force_download=force_download,
371
+ resume_download=resume_download,
372
+ proxies=proxies,
373
+ local_files_only=local_files_only,
374
+ use_auth_token=use_auth_token,
375
+ revision=revision,
376
+ subfolder=subfolder,
377
+ user_agent=user_agent,
378
+ commit_hash=commit_hash,
379
+ )
380
+ except IOError as e:
381
+ if not allow_pickle:
382
+ raise e
383
+ pass
384
+ if model_file is None:
385
+ model_file = _get_model_file(
386
+ pretrained_model_name_or_path,
387
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
388
+ cache_dir=cache_dir,
389
+ force_download=force_download,
390
+ resume_download=resume_download,
391
+ proxies=proxies,
392
+ local_files_only=local_files_only,
393
+ use_auth_token=use_auth_token,
394
+ revision=revision,
395
+ subfolder=subfolder,
396
+ user_agent=user_agent,
397
+ commit_hash=commit_hash,
398
+ )
399
+
400
+ if low_cpu_mem_usage:
401
+ # Instantiate model with empty weights
402
+ with accelerate.init_empty_weights():
403
+ model = cls.from_config(config, **unused_kwargs)
404
+
405
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
406
+ if device_map is None:
407
+ param_device = "cpu"
408
+ state_dict = load_state_dict(model_file, variant=variant)
409
+ # move the params from meta device to cpu
410
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
411
+ if len(missing_keys) > 0:
412
+ raise ValueError(
413
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
414
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
415
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
416
+ " those weights or else make sure your checkpoint file is correct."
417
+ )
418
+
419
+ empty_state_dict = model.state_dict()
420
+ for param_name, param in state_dict.items():
421
+ accepts_dtype = "dtype" in set(
422
+ inspect.signature(set_module_tensor_to_device).parameters.keys()
423
+ )
424
+
425
+ if empty_state_dict[param_name].shape != param.shape:
426
+ raise ValueError(
427
+ f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
428
+ )
429
+
430
+ if accepts_dtype:
431
+ set_module_tensor_to_device(
432
+ model, param_name, param_device, value=param, dtype=torch_dtype
433
+ )
434
+ else:
435
+ set_module_tensor_to_device(model, param_name, param_device, value=param)
436
+ else: # else let accelerate handle loading and dispatching.
437
+ # Load weights and dispatch according to the device_map
438
+ # by default the device_map is None and the weights are loaded on the CPU
439
+ accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
440
+
441
+ loading_info = {
442
+ "missing_keys": [],
443
+ "unexpected_keys": [],
444
+ "mismatched_keys": [],
445
+ "error_msgs": [],
446
+ }
447
+ else:
448
+ model = cls.from_config(config, **unused_kwargs)
449
+
450
+ state_dict = load_state_dict(model_file, variant=variant)
451
+
452
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
453
+ model,
454
+ state_dict,
455
+ model_file,
456
+ pretrained_model_name_or_path,
457
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
458
+ )
459
+
460
+ loading_info = {
461
+ "missing_keys": missing_keys,
462
+ "unexpected_keys": unexpected_keys,
463
+ "mismatched_keys": mismatched_keys,
464
+ "error_msgs": error_msgs,
465
+ }
466
+
467
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
468
+ raise ValueError(
469
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
470
+ )
471
+ elif torch_dtype is not None:
472
+ model = model.to(torch_dtype)
473
+
474
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
475
+
476
+ # Set model in evaluation mode to deactivate DropOut modules by default
477
+ model.eval()
478
+ if output_loading_info:
479
+ return model, loading_info
480
+
481
+ return model
482
+
483
+ @classmethod
484
+ def _load_pretrained_model(
485
+ cls,
486
+ model,
487
+ state_dict,
488
+ resolved_archive_file,
489
+ pretrained_model_name_or_path,
490
+ ignore_mismatched_sizes=False,
491
+ ):
492
+ # Retrieve missing & unexpected_keys
493
+ model_state_dict = model.state_dict()
494
+ loaded_keys = list(state_dict.keys())
495
+
496
+ expected_keys = list(model_state_dict.keys())
497
+
498
+ original_loaded_keys = loaded_keys
499
+
500
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
501
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
502
+
503
+ # Make sure we are able to load base models as well as derived models (with heads)
504
+ model_to_load = model
505
+
506
+ def _find_mismatched_keys(
507
+ state_dict,
508
+ model_state_dict,
509
+ loaded_keys,
510
+ ignore_mismatched_sizes,
511
+ ):
512
+ mismatched_keys = []
513
+ if ignore_mismatched_sizes:
514
+ for checkpoint_key in loaded_keys:
515
+ model_key = checkpoint_key
516
+
517
+ if (
518
+ model_key in model_state_dict
519
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
520
+ ):
521
+ mismatched_keys.append(
522
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
523
+ )
524
+ del state_dict[checkpoint_key]
525
+ return mismatched_keys
526
+
527
+ if state_dict is not None:
528
+ # Whole checkpoint
529
+ mismatched_keys = _find_mismatched_keys(
530
+ state_dict,
531
+ model_state_dict,
532
+ original_loaded_keys,
533
+ ignore_mismatched_sizes,
534
+ )
535
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
536
+
537
+ if len(error_msgs) > 0:
538
+ error_msg = "\n\t".join(error_msgs)
539
+ if "size mismatch" in error_msg:
540
+ error_msg += (
541
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
542
+ )
543
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
544
+
545
+ if len(unexpected_keys) > 0:
546
+ logger.warning(
547
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
548
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
549
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
550
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
551
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
552
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
553
+ " identical (initializing a BertForSequenceClassification model from a"
554
+ " BertForSequenceClassification model)."
555
+ )
556
+ else:
557
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
558
+ if len(missing_keys) > 0:
559
+ logger.warning(
560
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
561
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
562
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
563
+ )
564
+ elif len(mismatched_keys) == 0:
565
+ logger.info(
566
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
567
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
568
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
569
+ " without further training."
570
+ )
571
+ if len(mismatched_keys) > 0:
572
+ mismatched_warning = "\n".join(
573
+ [
574
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
575
+ for key, shape1, shape2 in mismatched_keys
576
+ ]
577
+ )
578
+ logger.warning(
579
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
580
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
581
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
582
+ " able to use it for predictions and inference."
583
+ )
584
+
585
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
586
+
587
+ @property
588
+ def device(self) -> device:
589
+ """
590
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
591
+ device).
592
+ """
593
+ return get_parameter_device(self)
594
+
595
+ @property
596
+ def dtype(self) -> torch.dtype:
597
+ """
598
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
599
+ """
600
+ return get_parameter_dtype(self)
601
+
602
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
603
+ """
604
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
605
+ Args:
606
+ only_trainable (`bool`, *optional*, defaults to `False`):
607
+ Whether or not to return only the number of trainable parameters
608
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
609
+ Whether or not to return only the number of non-embeddings parameters
610
+ Returns:
611
+ `int`: The number of parameters.
612
+ """
613
+
614
+ if exclude_embeddings:
615
+ embedding_param_names = [
616
+ f"{name}.weight"
617
+ for name, module_type in self.named_modules()
618
+ if isinstance(module_type, torch.nn.Embedding)
619
+ ]
620
+ non_embedding_parameters = [
621
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
622
+ ]
623
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
624
+ else:
625
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
626
+
627
+ def Downsample(dim, dim_out):
628
+ return nn.Conv2d(dim, dim_out, 4, 2, 1)
629
+
630
+ class Residual(nn.Sequential):
631
+ def forward(self, input):
632
+ x = input
633
+ for module in self:
634
+ x = module(x)
635
+ return x + input
636
+
637
+ def ConvLayer(dim, dim_out, *, kernel_size=3, groups=32):
638
+ return nn.Sequential(
639
+ nn.GroupNorm(groups, dim),
640
+ nn.SiLU(),
641
+ nn.Conv2d(dim, dim_out, kernel_size=kernel_size, padding=kernel_size//2),
642
+ )
643
+
644
+ def ResnetBlock(dim, *, kernel_size=3, groups=32):
645
+ return Residual(
646
+ ConvLayer(dim, dim, kernel_size=kernel_size, groups=groups),
647
+ ConvLayer(dim, dim, kernel_size=kernel_size, groups=groups),
648
+ )
649
+
650
+ class SelfAttention(nn.Module):
651
+ def __init__(self, dim, out_dim, *, heads=8, key_dim=32, value_dim=32):
652
+ super().__init__()
653
+ self.dim = dim
654
+ self.out_dim = out_dim
655
+ self.heads = heads
656
+ self.key_dim = key_dim
657
+
658
+ self.to_k = nn.Linear(dim, key_dim)
659
+ self.to_v = nn.Linear(dim, value_dim)
660
+ self.to_q = nn.Linear(dim, key_dim * heads)
661
+ self.to_out = nn.Linear(value_dim * heads, out_dim)
662
+
663
+ def forward(self, x):
664
+ shape = x.shape
665
+ x = einops.rearrange(x, 'b c ... -> b (...) c')
666
+
667
+ k = self.to_k(x)
668
+ v = self.to_v(x)
669
+ q = self.to_q(x)
670
+ q = einops.rearrange(q, 'b n (h c) -> b (n h) c', h=self.heads)
671
+ if hasattr(nn.functional, "scaled_dot_product_attention"):
672
+ result = F.scaled_dot_product_attention(q, k, v)
673
+ else:
674
+ attention_scores = torch.bmm(q, k.transpose(-2, -1))
675
+ attention_probs = torch.softmax(attention_scores.float() / math.sqrt(self.key_dim), dim=-1).type(attention_scores.dtype)
676
+ result = torch.bmm(attention_probs, v)
677
+ result = einops.rearrange(result, 'b (n h) c -> b n (h c)', h=self.heads)
678
+ out = self.to_out(result)
679
+
680
+ out = einops.rearrange(out, 'b n c -> b c n')
681
+ out = torch.reshape(out, (shape[0], self.out_dim, *shape[2:]))
682
+ return out
683
+
684
+ def SelfAttentionBlock(dim, attention_dim, *, heads=8, groups=32):
685
+ if not attention_dim:
686
+ attention_dim = dim // heads
687
+ return Residual(
688
+ nn.GroupNorm(groups, dim),
689
+ SelfAttention(dim, dim, heads=heads, key_dim=attention_dim, value_dim=attention_dim),
690
+ )
691
+
692
+ class Discriminator2D(ModelMixin, ConfigMixin):
693
+ @register_to_config
694
+ def __init__(
695
+ self,
696
+ in_channels: int = 8,
697
+ out_channels: int = 1,
698
+ block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024, 1024),
699
+ block_repeats: Tuple[int] = (2, 2, 2, 2, 2),
700
+ downsample_blocks: Tuple[int] = (0, 1, 2),
701
+ attention_blocks: Tuple[int] = (1, 2, 3, 4),
702
+ mlp_hidden_channels: Tuple[int] = (2048, 2048, 2048),
703
+ mlp_uses_norm: bool = True,
704
+ attention_dim: Optional[int] = None,
705
+ attention_heads: int = 8,
706
+ groups: int = 32,
707
+ embedding_dim: int = 768,
708
+ ):
709
+ super().__init__()
710
+
711
+ self.blocks = nn.ModuleList([])
712
+
713
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], 7, padding=3)
714
+
715
+ for i in range(0, len(block_out_channels) - 1):
716
+ block_in = block_out_channels[i]
717
+ block_out = block_out_channels[i + 1]
718
+ block = nn.Sequential()
719
+ for j in range(0, block_repeats[i]):
720
+ if i in attention_blocks:
721
+ block.append(SelfAttentionBlock(block_in, attention_dim, heads=attention_heads, groups=groups))
722
+ block.append(ResnetBlock(block_in, groups=groups))
723
+ if i in downsample_blocks:
724
+ block.append(Downsample(block_in, block_out))
725
+ elif block_in != block_out:
726
+ block.append(nn.Conv2d(block_in, block_out, 1))
727
+ self.blocks.append(block)
728
+
729
+ # A simple MLP to make the final decision based on statistics from
730
+ # the output of every block
731
+ self.to_out = nn.Sequential()
732
+ d_channels = 2 * sum(block_out_channels[1:]) + embedding_dim
733
+ for c in mlp_hidden_channels:
734
+ self.to_out.append(nn.Linear(d_channels, c))
735
+ if mlp_uses_norm:
736
+ self.to_out.append(nn.GroupNorm(groups, c))
737
+ self.to_out.append(nn.SiLU())
738
+ d_channels = c
739
+ self.to_out.append(nn.Linear(d_channels, out_channels))
740
+
741
+ self.gradient_checkpointing = False
742
+
743
+ def enable_gradient_checkpointing(self):
744
+ self.gradient_checkpointing = True
745
+
746
+ def disable_gradient_checkpointing(self):
747
+ self.gradient_checkpointing = False
748
+
749
+ def forward(self, x, encoder_hidden_states):
750
+ x = self.conv_in(x)
751
+ if self.config.embedding_dim != 0:
752
+ d = einops.reduce(encoder_hidden_states, 'b n c -> b c', 'mean')
753
+ else:
754
+ d = torch.zeros([x.shape[0], 0], device=x.device, dtype=x.dtype)
755
+ for block in self.blocks:
756
+ if self.gradient_checkpointing:
757
+ x = torch.utils.checkpoint.checkpoint(block, x)
758
+ else:
759
+ x = block(x)
760
+ x_mean = einops.reduce(x, 'b c ... -> b c', 'mean')
761
+ x_max = einops.reduce(x, 'b c ... -> b c', 'max')
762
+ d = torch.cat([d, x_mean, x_max], dim=-1)
763
+ return self.to_out(d)
764
+
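A minimal sketch of instantiating the discriminator defined above (not part of the commit). The tensor shapes are assumptions chosen to match SD-style latents with the default config (8 input channels, 768-dim text embeddings); only the `Discriminator2D` constructor and `forward` come from the code above.

```py
import torch

disc = Discriminator2D()  # defaults: in_channels=8, embedding_dim=768

latents = torch.randn(2, 8, 64, 64)              # e.g. noisy latents concatenated with a prediction (assumed layout)
encoder_hidden_states = torch.randn(2, 77, 768)  # CLIP text embeddings: (batch, tokens, dim)

logits = disc(latents, encoder_hidden_states)    # per-sample realness logits, shape (2, 1)
```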
StableTuner_RunPod_Fix/lion_pytorch.py ADDED
@@ -0,0 +1,88 @@
1
+ from typing import Tuple, Optional, Callable
2
+
3
+ import torch
4
+ from torch.optim.optimizer import Optimizer
5
+
6
+ # functions
7
+
8
+ def exists(val):
9
+ return val is not None
10
+
11
+ # update functions
12
+
13
+ def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
14
+ # stepweight decay
15
+
16
+ p.data.mul_(1 - lr * wd)
17
+
18
+ # weight update
19
+
20
+ update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1 - beta1).sign_()
21
+ p.add_(update, alpha = -lr)
22
+
23
+ # decay the momentum running average coefficient
24
+
25
+ exp_avg.mul_(beta2).add_(grad, alpha = 1 - beta2)
26
+
27
+ # class
28
+
29
+ class Lion(Optimizer):
30
+ def __init__(
31
+ self,
32
+ params,
33
+ lr: float = 1e-4,
34
+ betas: Tuple[float, float] = (0.9, 0.99),
35
+ weight_decay: float = 0.0,
36
+ use_triton: bool = False
37
+ ):
38
+ assert lr > 0.
39
+ assert all([0. <= beta <= 1. for beta in betas])
40
+
41
+ defaults = dict(
42
+ lr = lr,
43
+ betas = betas,
44
+ weight_decay = weight_decay
45
+ )
46
+
47
+ super().__init__(params, defaults)
48
+
49
+ self.update_fn = update_fn
50
+
51
+ if use_triton:
52
+ from lion_pytorch.triton import update_fn as triton_update_fn
53
+ self.update_fn = triton_update_fn
54
+
55
+ @torch.no_grad()
56
+ def step(
57
+ self,
58
+ closure: Optional[Callable] = None
59
+ ):
60
+
61
+ loss = None
62
+ if exists(closure):
63
+ with torch.enable_grad():
64
+ loss = closure()
65
+
66
+ for group in self.param_groups:
67
+ for p in filter(lambda p: exists(p.grad), group['params']):
68
+
69
+ grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p]
70
+
71
+ # init state - exponential moving average of gradient values
72
+
73
+ if len(state) == 0:
74
+ state['exp_avg'] = torch.zeros_like(p)
75
+
76
+ exp_avg = state['exp_avg']
77
+
78
+ self.update_fn(
79
+ p,
80
+ grad,
81
+ exp_avg,
82
+ lr,
83
+ wd,
84
+ beta1,
85
+ beta2
86
+ )
87
+
88
+ return loss
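For reference, a minimal usage sketch of the Lion optimizer defined above (not part of the commit); the tiny model and hyperparameters are placeholders. The update interpolates the momentum and the gradient, takes its sign, applies decoupled weight decay, and then refreshes the momentum, exactly as in `update_fn` above; in practice Lion is usually run with a smaller learning rate and larger weight decay than AdamW.

```py
import torch
import torch.nn as nn

model = nn.Linear(10, 1)                                   # placeholder model
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)

x, y = torch.randn(8, 10), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()       # update direction is sign(beta1 * exp_avg + (1 - beta1) * grad), scaled by lr
optimizer.zero_grad()
```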
StableTuner_RunPod_Fix/lora_utils.py ADDED
@@ -0,0 +1,236 @@
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+
6
+ import math
7
+ import os
8
+ import torch
9
+
10
+ from trainer_util import *
11
+
12
+
13
+ class LoRAModule(torch.nn.Module):
14
+ """
15
+ Replaces the forward method of the original Linear (or 1x1 Conv2d) module instead of replacing the module itself.
16
+ """
17
+
18
+ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
19
+ """ if alpha == 0 or None, alpha is rank (no scaling). """
20
+ super().__init__()
21
+ self.lora_name = lora_name
22
+ self.lora_dim = lora_dim
23
+
24
+ if org_module.__class__.__name__ == 'Conv2d':
25
+ in_dim = org_module.in_channels
26
+ out_dim = org_module.out_channels
27
+ self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
28
+ self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
29
+ else:
30
+ in_dim = org_module.in_features
31
+ out_dim = org_module.out_features
32
+ self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
33
+ self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
34
+
35
+ if type(alpha) == torch.Tensor:
36
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
37
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
38
+ self.scale = alpha / self.lora_dim
39
+ self.register_buffer('alpha', torch.tensor(alpha)) # can be treated as a constant
40
+
41
+ # same as microsoft's
42
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
43
+ torch.nn.init.zeros_(self.lora_up.weight)
44
+
45
+ self.multiplier = multiplier
46
+ self.org_module = org_module # removed when apply_to() is called
47
+
48
+ def apply_to(self):
49
+ self.org_forward = self.org_module.forward
50
+ self.org_module.forward = self.forward
51
+ del self.org_module
52
+
53
+ def forward(self, x):
54
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
55
+
56
+
57
+ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
58
+ if network_dim is None:
59
+ network_dim = 4 # default
60
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
61
+ return network
62
+
63
+
64
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
65
+ if os.path.splitext(file)[1] == '.safetensors':
66
+ from safetensors.torch import load_file, safe_open
67
+ weights_sd = load_file(file)
68
+ else:
69
+ weights_sd = torch.load(file, map_location='cpu')
70
+
71
+ # get dim (rank)
72
+ network_alpha = None
73
+ network_dim = None
74
+ for key, value in weights_sd.items():
75
+ if network_alpha is None and 'alpha' in key:
76
+ network_alpha = value
77
+ if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
78
+ network_dim = value.size()[0]
79
+
80
+ if network_alpha is None:
81
+ network_alpha = network_dim
82
+
83
+ network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
84
+ network.weights_sd = weights_sd
85
+ return network
86
+
87
+
88
+ class LoRANetwork(torch.nn.Module):
89
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
90
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
91
+ LORA_PREFIX_UNET = 'lora_unet'
92
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
93
+
94
+ def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
95
+ super().__init__()
96
+ self.multiplier = multiplier
97
+ self.lora_dim = lora_dim
98
+ self.alpha = alpha
99
+
100
+ # create module instances
101
+ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
102
+ loras = []
103
+ for name, module in root_module.named_modules():
104
+ if module.__class__.__name__ in target_replace_modules:
105
+ for child_name, child_module in module.named_modules():
106
+ if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
107
+ lora_name = prefix + '.' + name + '.' + child_name
108
+ lora_name = lora_name.replace('.', '_')
109
+ lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
110
+ loras.append(lora)
111
+ return loras
112
+
113
+ self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
114
+ text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
115
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
116
+
117
+ self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
118
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
119
+
120
+ self.weights_sd = None
121
+
122
+ # assertion
123
+ names = set()
124
+ for lora in self.text_encoder_loras + self.unet_loras:
125
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
126
+ names.add(lora.lora_name)
127
+
128
+ def load_weights(self, file):
129
+ if os.path.splitext(file)[1] == '.safetensors':
130
+ from safetensors.torch import load_file, safe_open
131
+ self.weights_sd = load_file(file)
132
+ else:
133
+ self.weights_sd = torch.load(file, map_location='cpu')
134
+
135
+ def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
136
+ if self.weights_sd:
137
+ weights_has_text_encoder = weights_has_unet = False
138
+ for key in self.weights_sd.keys():
139
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
140
+ weights_has_text_encoder = True
141
+ elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
142
+ weights_has_unet = True
143
+
144
+ if apply_text_encoder is None:
145
+ apply_text_encoder = weights_has_text_encoder
146
+ else:
147
+ assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / the weights and the Text Encoder flag are inconsistent"
148
+
149
+ if apply_unet is None:
150
+ apply_unet = weights_has_unet
151
+ else:
152
+ assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / the weights and the U-Net flag are inconsistent"
153
+ else:
154
+ assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
155
+
156
+ if apply_text_encoder:
157
+ print("enable LoRA for text encoder")
158
+ else:
159
+ self.text_encoder_loras = []
160
+
161
+ if apply_unet:
162
+ print("enable LoRA for U-Net")
163
+ else:
164
+ self.unet_loras = []
165
+
166
+ for lora in self.text_encoder_loras + self.unet_loras:
167
+ lora.apply_to()
168
+ self.add_module(lora.lora_name, lora)
169
+
170
+ if self.weights_sd:
171
+ # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
172
+ info = self.load_state_dict(self.weights_sd, False)
173
+ print(f"weights are loaded: {info}")
174
+
175
+ def enable_gradient_checkpointing(self):
176
+ # not supported
177
+ pass
178
+
179
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
180
+ def enumerate_params(loras):
181
+ params = []
182
+ for lora in loras:
183
+ params.extend(lora.parameters())
184
+ return params
185
+
186
+ self.requires_grad_(True)
187
+ all_params = []
188
+
189
+ if self.text_encoder_loras:
190
+ param_data = {'params': enumerate_params(self.text_encoder_loras)}
191
+ if text_encoder_lr is not None:
192
+ param_data['lr'] = text_encoder_lr
193
+ all_params.append(param_data)
194
+
195
+ if self.unet_loras:
196
+ param_data = {'params': enumerate_params(self.unet_loras)}
197
+ if unet_lr is not None:
198
+ param_data['lr'] = unet_lr
199
+ all_params.append(param_data)
200
+
201
+ return all_params
202
+
203
+ def prepare_grad_etc(self, text_encoder, unet):
204
+ self.requires_grad_(True)
205
+
206
+ def on_epoch_start(self, text_encoder, unet):
207
+ self.train()
208
+
209
+ def get_trainable_params(self):
210
+ return self.parameters()
211
+
212
+ def save_weights(self, file, dtype, metadata):
213
+ if metadata is not None and len(metadata) == 0:
214
+ metadata = None
215
+
216
+ state_dict = self.state_dict()
217
+
218
+ if dtype is not None:
219
+ for key in list(state_dict.keys()):
220
+ v = state_dict[key]
221
+ v = v.detach().clone().to("cpu").to(dtype)
222
+ state_dict[key] = v
223
+
224
+ if os.path.splitext(file)[1] == '.safetensors':
225
+ from safetensors.torch import save_file
226
+
227
+ # Precalculate model hashes to save time on indexing
228
+ if metadata is None:
229
+ metadata = {}
230
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
231
+ metadata["sshs_model_hash"] = model_hash
232
+ metadata["sshs_legacy_hash"] = legacy_hash
233
+
234
+ save_file(state_dict, file, metadata)
235
+ else:
236
+ torch.save(state_dict, file)
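A minimal sketch of wiring the LoRA network above into a Stable Diffusion text encoder and U-Net (not part of the commit). The pipeline checkpoint, ranks, and learning rates are placeholder assumptions; `create_network`, `apply_to`, `prepare_optimizer_params`, and `save_weights` come from the code above. The `.pt` extension is used here because the safetensors branch of `save_weights` additionally calls `train_util.precalculate_safetensors_hashes`, which is not explicitly imported in this file.

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
text_encoder, unet, vae = pipe.text_encoder, pipe.unet, pipe.vae

network = create_network(multiplier=1.0, network_dim=8, network_alpha=4,
                         vae=vae, text_encoder=text_encoder, unet=unet)
network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)

params = network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4)
optimizer = torch.optim.AdamW(params)

# ... run the usual training loop; the optimizer above only updates the LoRA down/up projections ...

network.save_weights("lora_weights.pt", torch.float16, metadata=None)
```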
StableTuner_RunPod_Fix/model_util.py ADDED
@@ -0,0 +1,1543 @@
1
+ # v1: split from train_db_fixed.py.
2
+ # v2: support safetensors
3
+
4
+ import math
5
+ import os
6
+ import torch
7
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
8
+ from diffusers import (
9
+ AutoencoderKL,
10
+ DDIMScheduler,
11
+ StableDiffusionPipeline,
12
+ UNet2DConditionModel,
13
+ )
14
+ from safetensors.torch import load_file, save_file
15
+
16
+ # Model parameters of the Diffusers version of Stable Diffusion
17
+ NUM_TRAIN_TIMESTEPS = 1000
18
+ BETA_START = 0.00085
19
+ BETA_END = 0.0120
20
+
21
+ UNET_PARAMS_MODEL_CHANNELS = 320
22
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
23
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
24
+ UNET_PARAMS_IMAGE_SIZE = 32 # unused
25
+ UNET_PARAMS_IN_CHANNELS = 4
26
+ UNET_PARAMS_OUT_CHANNELS = 4
27
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
28
+ UNET_PARAMS_CONTEXT_DIM = 768
29
+ UNET_PARAMS_NUM_HEADS = 8
30
+
31
+ VAE_PARAMS_Z_CHANNELS = 4
32
+ VAE_PARAMS_RESOLUTION = 256
33
+ VAE_PARAMS_IN_CHANNELS = 3
34
+ VAE_PARAMS_OUT_CH = 3
35
+ VAE_PARAMS_CH = 128
36
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
37
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
38
+
39
+ # V2
40
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
41
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
42
+
43
+ # Reference models used to load the Diffusers configuration
44
+ DIFFUSERS_REF_MODEL_ID_V1 = 'runwayml/stable-diffusion-v1-5'
45
+ DIFFUSERS_REF_MODEL_ID_V2 = 'stabilityai/stable-diffusion-2-1'
46
+
47
+
48
+ # region StableDiffusion -> Diffusers conversion code
49
+ # copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
50
+
51
+
52
+ def shave_segments(path, n_shave_prefix_segments=1):
53
+ """
54
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
55
+ """
56
+ if n_shave_prefix_segments >= 0:
57
+ return '.'.join(path.split('.')[n_shave_prefix_segments:])
58
+ else:
59
+ return '.'.join(path.split('.')[:n_shave_prefix_segments])
60
+
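A couple of quick illustrations of shave_segments, worked out by hand from the definition above:

shave_segments('model.diffusion_model.input_blocks.0.0.weight', 2)   # -> 'input_blocks.0.0.weight'
shave_segments('model.diffusion_model.input_blocks.0.0.weight', -1)  # -> 'model.diffusion_model.input_blocks.0.0'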
61
+
62
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
63
+ """
64
+ Updates paths inside resnets to the new naming scheme (local renaming)
65
+ """
66
+ mapping = []
67
+ for old_item in old_list:
68
+ new_item = old_item.replace('in_layers.0', 'norm1')
69
+ new_item = new_item.replace('in_layers.2', 'conv1')
70
+
71
+ new_item = new_item.replace('out_layers.0', 'norm2')
72
+ new_item = new_item.replace('out_layers.3', 'conv2')
73
+
74
+ new_item = new_item.replace('emb_layers.1', 'time_emb_proj')
75
+ new_item = new_item.replace('skip_connection', 'conv_shortcut')
76
+
77
+ new_item = shave_segments(
78
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
79
+ )
80
+
81
+ mapping.append({'old': old_item, 'new': new_item})
82
+
83
+ return mapping
84
+
85
+
86
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
87
+ """
88
+ Updates paths inside resnets to the new naming scheme (local renaming)
89
+ """
90
+ mapping = []
91
+ for old_item in old_list:
92
+ new_item = old_item
93
+
94
+ new_item = new_item.replace('nin_shortcut', 'conv_shortcut')
95
+ new_item = shave_segments(
96
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
97
+ )
98
+
99
+ mapping.append({'old': old_item, 'new': new_item})
100
+
101
+ return mapping
102
+
103
+
104
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
105
+ """
106
+ Updates paths inside attentions to the new naming scheme (local renaming)
107
+ """
108
+ mapping = []
109
+ for old_item in old_list:
110
+ new_item = old_item
111
+
112
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
113
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
114
+
115
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
116
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
117
+
118
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
119
+
120
+ mapping.append({'old': old_item, 'new': new_item})
121
+
122
+ return mapping
123
+
124
+
125
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
126
+ """
127
+ Updates paths inside attentions to the new naming scheme (local renaming)
128
+ """
129
+ mapping = []
130
+ for old_item in old_list:
131
+ new_item = old_item
132
+
133
+ new_item = new_item.replace('norm.weight', 'group_norm.weight')
134
+ new_item = new_item.replace('norm.bias', 'group_norm.bias')
135
+
136
+ new_item = new_item.replace('q.weight', 'query.weight')
137
+ new_item = new_item.replace('q.bias', 'query.bias')
138
+
139
+ new_item = new_item.replace('k.weight', 'key.weight')
140
+ new_item = new_item.replace('k.bias', 'key.bias')
141
+
142
+ new_item = new_item.replace('v.weight', 'value.weight')
143
+ new_item = new_item.replace('v.bias', 'value.bias')
144
+
145
+ new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
146
+ new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
147
+
148
+ new_item = shave_segments(
149
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
150
+ )
151
+
152
+ mapping.append({'old': old_item, 'new': new_item})
153
+
154
+ return mapping
155
+
156
+
157
+ def assign_to_checkpoint(
158
+ paths,
159
+ checkpoint,
160
+ old_checkpoint,
161
+ attention_paths_to_split=None,
162
+ additional_replacements=None,
163
+ config=None,
164
+ ):
165
+ """
166
+ This does the final conversion step: take locally converted weights and apply a global renaming
167
+ to them. It splits attention layers, and takes into account additional replacements
168
+ that may arise.
169
+ Assigns the weights to the new checkpoint.
170
+ """
171
+ assert isinstance(
172
+ paths, list
173
+ ), "Paths should be a list of dicts containing 'old' and 'new' keys."
174
+
175
+ # Splits the attention layers into three variables.
176
+ if attention_paths_to_split is not None:
177
+ for path, path_map in attention_paths_to_split.items():
178
+ old_tensor = old_checkpoint[path]
179
+ channels = old_tensor.shape[0] // 3
180
+
181
+ target_shape = (
182
+ (-1, channels) if len(old_tensor.shape) == 3 else (-1)
183
+ )
184
+
185
+ num_heads = old_tensor.shape[0] // config['num_head_channels'] // 3
186
+
187
+ old_tensor = old_tensor.reshape(
188
+ (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
189
+ )
190
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
191
+
192
+ checkpoint[path_map['query']] = query.reshape(target_shape)
193
+ checkpoint[path_map['key']] = key.reshape(target_shape)
194
+ checkpoint[path_map['value']] = value.reshape(target_shape)
195
+
196
+ for path in paths:
197
+ new_path = path['new']
198
+
199
+ # These have already been assigned
200
+ if (
201
+ attention_paths_to_split is not None
202
+ and new_path in attention_paths_to_split
203
+ ):
204
+ continue
205
+
206
+ # Global renaming happens here
207
+ new_path = new_path.replace('middle_block.0', 'mid_block.resnets.0')
208
+ new_path = new_path.replace('middle_block.1', 'mid_block.attentions.0')
209
+ new_path = new_path.replace('middle_block.2', 'mid_block.resnets.1')
210
+
211
+ if additional_replacements is not None:
212
+ for replacement in additional_replacements:
213
+ new_path = new_path.replace(
214
+ replacement['old'], replacement['new']
215
+ )
216
+
217
+ # proj_attn.weight has to be converted from conv 1D to linear
218
+ if 'proj_attn.weight' in new_path:
219
+ checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
220
+ else:
221
+ checkpoint[new_path] = old_checkpoint[path['old']]
222
+
223
+
224
+ def conv_attn_to_linear(checkpoint):
225
+ keys = list(checkpoint.keys())
226
+ attn_keys = ['query.weight', 'key.weight', 'value.weight']
227
+ for key in keys:
228
+ if '.'.join(key.split('.')[-2:]) in attn_keys:
229
+ if checkpoint[key].ndim > 2:
230
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
231
+ elif 'proj_attn.weight' in key:
232
+ if checkpoint[key].ndim > 2:
233
+ checkpoint[key] = checkpoint[key][:, :, 0]
234
+
235
+
236
+ def linear_transformer_to_conv(checkpoint):
237
+ keys = list(checkpoint.keys())
238
+ tf_keys = ['proj_in.weight', 'proj_out.weight']
239
+ for key in keys:
240
+ if '.'.join(key.split('.')[-2:]) in tf_keys:
241
+ if checkpoint[key].ndim == 2:
242
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
243
+
244
+
245
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
246
+ """
247
+ Takes a state dict and a config, and returns a converted checkpoint.
248
+ """
249
+
250
+ # extract state_dict for UNet
251
+ unet_state_dict = {}
252
+ unet_key = 'model.diffusion_model.'
253
+ keys = list(checkpoint.keys())
254
+ for key in keys:
255
+ if key.startswith(unet_key):
256
+ unet_state_dict[key.replace(unet_key, '')] = checkpoint.pop(key)
257
+
258
+ new_checkpoint = {}
259
+
260
+ new_checkpoint['time_embedding.linear_1.weight'] = unet_state_dict[
261
+ 'time_embed.0.weight'
262
+ ]
263
+ new_checkpoint['time_embedding.linear_1.bias'] = unet_state_dict[
264
+ 'time_embed.0.bias'
265
+ ]
266
+ new_checkpoint['time_embedding.linear_2.weight'] = unet_state_dict[
267
+ 'time_embed.2.weight'
268
+ ]
269
+ new_checkpoint['time_embedding.linear_2.bias'] = unet_state_dict[
270
+ 'time_embed.2.bias'
271
+ ]
272
+
273
+ new_checkpoint['conv_in.weight'] = unet_state_dict[
274
+ 'input_blocks.0.0.weight'
275
+ ]
276
+ new_checkpoint['conv_in.bias'] = unet_state_dict['input_blocks.0.0.bias']
277
+
278
+ new_checkpoint['conv_norm_out.weight'] = unet_state_dict['out.0.weight']
279
+ new_checkpoint['conv_norm_out.bias'] = unet_state_dict['out.0.bias']
280
+ new_checkpoint['conv_out.weight'] = unet_state_dict['out.2.weight']
281
+ new_checkpoint['conv_out.bias'] = unet_state_dict['out.2.bias']
282
+
283
+ # Retrieves the keys for the input blocks only
284
+ num_input_blocks = len(
285
+ {
286
+ '.'.join(layer.split('.')[:2])
287
+ for layer in unet_state_dict
288
+ if 'input_blocks' in layer
289
+ }
290
+ )
291
+ input_blocks = {
292
+ layer_id: [
293
+ key
294
+ for key in unet_state_dict
295
+ if f'input_blocks.{layer_id}.' in key
296
+ ]
297
+ for layer_id in range(num_input_blocks)
298
+ }
299
+
300
+ # Retrieves the keys for the middle blocks only
301
+ num_middle_blocks = len(
302
+ {
303
+ '.'.join(layer.split('.')[:2])
304
+ for layer in unet_state_dict
305
+ if 'middle_block' in layer
306
+ }
307
+ )
308
+ middle_blocks = {
309
+ layer_id: [
310
+ key
311
+ for key in unet_state_dict
312
+ if f'middle_block.{layer_id}.' in key
313
+ ]
314
+ for layer_id in range(num_middle_blocks)
315
+ }
316
+
317
+ # Retrieves the keys for the output blocks only
318
+ num_output_blocks = len(
319
+ {
320
+ '.'.join(layer.split('.')[:2])
321
+ for layer in unet_state_dict
322
+ if 'output_blocks' in layer
323
+ }
324
+ )
325
+ output_blocks = {
326
+ layer_id: [
327
+ key
328
+ for key in unet_state_dict
329
+ if f'output_blocks.{layer_id}.' in key
330
+ ]
331
+ for layer_id in range(num_output_blocks)
332
+ }
333
+
334
+ for i in range(1, num_input_blocks):
335
+ block_id = (i - 1) // (config['layers_per_block'] + 1)
336
+ layer_in_block_id = (i - 1) % (config['layers_per_block'] + 1)
337
+
338
+ resnets = [
339
+ key
340
+ for key in input_blocks[i]
341
+ if f'input_blocks.{i}.0' in key
342
+ and f'input_blocks.{i}.0.op' not in key
343
+ ]
344
+ attentions = [
345
+ key for key in input_blocks[i] if f'input_blocks.{i}.1' in key
346
+ ]
347
+
348
+ if f'input_blocks.{i}.0.op.weight' in unet_state_dict:
349
+ new_checkpoint[
350
+ f'down_blocks.{block_id}.downsamplers.0.conv.weight'
351
+ ] = unet_state_dict.pop(f'input_blocks.{i}.0.op.weight')
352
+ new_checkpoint[
353
+ f'down_blocks.{block_id}.downsamplers.0.conv.bias'
354
+ ] = unet_state_dict.pop(f'input_blocks.{i}.0.op.bias')
355
+
356
+ paths = renew_resnet_paths(resnets)
357
+ meta_path = {
358
+ 'old': f'input_blocks.{i}.0',
359
+ 'new': f'down_blocks.{block_id}.resnets.{layer_in_block_id}',
360
+ }
361
+ assign_to_checkpoint(
362
+ paths,
363
+ new_checkpoint,
364
+ unet_state_dict,
365
+ additional_replacements=[meta_path],
366
+ config=config,
367
+ )
368
+
369
+ if len(attentions):
370
+ paths = renew_attention_paths(attentions)
371
+ meta_path = {
372
+ 'old': f'input_blocks.{i}.1',
373
+ 'new': f'down_blocks.{block_id}.attentions.{layer_in_block_id}',
374
+ }
375
+ assign_to_checkpoint(
376
+ paths,
377
+ new_checkpoint,
378
+ unet_state_dict,
379
+ additional_replacements=[meta_path],
380
+ config=config,
381
+ )
382
+
383
+ resnet_0 = middle_blocks[0]
384
+ attentions = middle_blocks[1]
385
+ resnet_1 = middle_blocks[2]
386
+
387
+ resnet_0_paths = renew_resnet_paths(resnet_0)
388
+ assign_to_checkpoint(
389
+ resnet_0_paths, new_checkpoint, unet_state_dict, config=config
390
+ )
391
+
392
+ resnet_1_paths = renew_resnet_paths(resnet_1)
393
+ assign_to_checkpoint(
394
+ resnet_1_paths, new_checkpoint, unet_state_dict, config=config
395
+ )
396
+
397
+ attentions_paths = renew_attention_paths(attentions)
398
+ meta_path = {'old': 'middle_block.1', 'new': 'mid_block.attentions.0'}
399
+ assign_to_checkpoint(
400
+ attentions_paths,
401
+ new_checkpoint,
402
+ unet_state_dict,
403
+ additional_replacements=[meta_path],
404
+ config=config,
405
+ )
406
+
407
+ for i in range(num_output_blocks):
408
+ block_id = i // (config['layers_per_block'] + 1)
409
+ layer_in_block_id = i % (config['layers_per_block'] + 1)
410
+ output_block_layers = [
411
+ shave_segments(name, 2) for name in output_blocks[i]
412
+ ]
413
+ output_block_list = {}
414
+
415
+ for layer in output_block_layers:
416
+ layer_id, layer_name = layer.split('.')[0], shave_segments(
417
+ layer, 1
418
+ )
419
+ if layer_id in output_block_list:
420
+ output_block_list[layer_id].append(layer_name)
421
+ else:
422
+ output_block_list[layer_id] = [layer_name]
423
+
424
+ if len(output_block_list) > 1:
425
+ resnets = [
426
+ key
427
+ for key in output_blocks[i]
428
+ if f'output_blocks.{i}.0' in key
429
+ ]
430
+ attentions = [
431
+ key
432
+ for key in output_blocks[i]
433
+ if f'output_blocks.{i}.1' in key
434
+ ]
435
+
436
+ resnet_0_paths = renew_resnet_paths(resnets)
437
+ paths = renew_resnet_paths(resnets)
438
+
439
+ meta_path = {
440
+ 'old': f'output_blocks.{i}.0',
441
+ 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}',
442
+ }
443
+ assign_to_checkpoint(
444
+ paths,
445
+ new_checkpoint,
446
+ unet_state_dict,
447
+ additional_replacements=[meta_path],
448
+ config=config,
449
+ )
450
+
451
+ # original:
452
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
453
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
454
+
455
+ # avoid depending on the order of bias and weight; there is probably a better way to do this
456
+ for l in output_block_list.values():
457
+ l.sort()
458
+
459
+ if ['conv.bias', 'conv.weight'] in output_block_list.values():
460
+ index = list(output_block_list.values()).index(
461
+ ['conv.bias', 'conv.weight']
462
+ )
463
+ new_checkpoint[
464
+ f'up_blocks.{block_id}.upsamplers.0.conv.bias'
465
+ ] = unet_state_dict[f'output_blocks.{i}.{index}.conv.bias']
466
+ new_checkpoint[
467
+ f'up_blocks.{block_id}.upsamplers.0.conv.weight'
468
+ ] = unet_state_dict[f'output_blocks.{i}.{index}.conv.weight']
469
+
470
+ # Clear attentions as they have been attributed above.
471
+ if len(attentions) == 2:
472
+ attentions = []
473
+
474
+ if len(attentions):
475
+ paths = renew_attention_paths(attentions)
476
+ meta_path = {
477
+ 'old': f'output_blocks.{i}.1',
478
+ 'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}',
479
+ }
480
+ assign_to_checkpoint(
481
+ paths,
482
+ new_checkpoint,
483
+ unet_state_dict,
484
+ additional_replacements=[meta_path],
485
+ config=config,
486
+ )
487
+ else:
488
+ resnet_0_paths = renew_resnet_paths(
489
+ output_block_layers, n_shave_prefix_segments=1
490
+ )
491
+ for path in resnet_0_paths:
492
+ old_path = '.'.join(['output_blocks', str(i), path['old']])
493
+ new_path = '.'.join(
494
+ [
495
+ 'up_blocks',
496
+ str(block_id),
497
+ 'resnets',
498
+ str(layer_in_block_id),
499
+ path['new'],
500
+ ]
501
+ )
502
+
503
+ new_checkpoint[new_path] = unet_state_dict[old_path]
504
+
505
+ # in SD v2 the 1x1 conv2d layers were replaced with linear layers, so convert linear -> conv
506
+ if v2:
507
+ linear_transformer_to_conv(new_checkpoint)
508
+
509
+ return new_checkpoint
510
+
511
+
512
+ def convert_ldm_vae_checkpoint(checkpoint, config):
513
+ # extract state dict for VAE
514
+ vae_state_dict = {}
515
+ vae_key = 'first_stage_model.'
516
+ keys = list(checkpoint.keys())
517
+ for key in keys:
518
+ if key.startswith(vae_key):
519
+ vae_state_dict[key.replace(vae_key, '')] = checkpoint.get(key)
520
+ # if len(vae_state_dict) == 0:
521
+ # # the checkpoint passed in is a VAE state_dict rather than a checkpoint loaded from a .ckpt
522
+ # vae_state_dict = checkpoint
523
+
524
+ new_checkpoint = {}
525
+
526
+ new_checkpoint['encoder.conv_in.weight'] = vae_state_dict[
527
+ 'encoder.conv_in.weight'
528
+ ]
529
+ new_checkpoint['encoder.conv_in.bias'] = vae_state_dict[
530
+ 'encoder.conv_in.bias'
531
+ ]
532
+ new_checkpoint['encoder.conv_out.weight'] = vae_state_dict[
533
+ 'encoder.conv_out.weight'
534
+ ]
535
+ new_checkpoint['encoder.conv_out.bias'] = vae_state_dict[
536
+ 'encoder.conv_out.bias'
537
+ ]
538
+ new_checkpoint['encoder.conv_norm_out.weight'] = vae_state_dict[
539
+ 'encoder.norm_out.weight'
540
+ ]
541
+ new_checkpoint['encoder.conv_norm_out.bias'] = vae_state_dict[
542
+ 'encoder.norm_out.bias'
543
+ ]
544
+
545
+ new_checkpoint['decoder.conv_in.weight'] = vae_state_dict[
546
+ 'decoder.conv_in.weight'
547
+ ]
548
+ new_checkpoint['decoder.conv_in.bias'] = vae_state_dict[
549
+ 'decoder.conv_in.bias'
550
+ ]
551
+ new_checkpoint['decoder.conv_out.weight'] = vae_state_dict[
552
+ 'decoder.conv_out.weight'
553
+ ]
554
+ new_checkpoint['decoder.conv_out.bias'] = vae_state_dict[
555
+ 'decoder.conv_out.bias'
556
+ ]
557
+ new_checkpoint['decoder.conv_norm_out.weight'] = vae_state_dict[
558
+ 'decoder.norm_out.weight'
559
+ ]
560
+ new_checkpoint['decoder.conv_norm_out.bias'] = vae_state_dict[
561
+ 'decoder.norm_out.bias'
562
+ ]
563
+
564
+ new_checkpoint['quant_conv.weight'] = vae_state_dict['quant_conv.weight']
565
+ new_checkpoint['quant_conv.bias'] = vae_state_dict['quant_conv.bias']
566
+ new_checkpoint['post_quant_conv.weight'] = vae_state_dict[
567
+ 'post_quant_conv.weight'
568
+ ]
569
+ new_checkpoint['post_quant_conv.bias'] = vae_state_dict[
570
+ 'post_quant_conv.bias'
571
+ ]
572
+
573
+ # Retrieves the keys for the encoder down blocks only
574
+ num_down_blocks = len(
575
+ {
576
+ '.'.join(layer.split('.')[:3])
577
+ for layer in vae_state_dict
578
+ if 'encoder.down' in layer
579
+ }
580
+ )
581
+ down_blocks = {
582
+ layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key]
583
+ for layer_id in range(num_down_blocks)
584
+ }
585
+
586
+ # Retrieves the keys for the decoder up blocks only
587
+ num_up_blocks = len(
588
+ {
589
+ '.'.join(layer.split('.')[:3])
590
+ for layer in vae_state_dict
591
+ if 'decoder.up' in layer
592
+ }
593
+ )
594
+ up_blocks = {
595
+ layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key]
596
+ for layer_id in range(num_up_blocks)
597
+ }
598
+
599
+ for i in range(num_down_blocks):
600
+ resnets = [
601
+ key
602
+ for key in down_blocks[i]
603
+ if f'down.{i}' in key and f'down.{i}.downsample' not in key
604
+ ]
605
+
606
+ if f'encoder.down.{i}.downsample.conv.weight' in vae_state_dict:
607
+ new_checkpoint[
608
+ f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'
609
+ ] = vae_state_dict.pop(f'encoder.down.{i}.downsample.conv.weight')
610
+ new_checkpoint[
611
+ f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'
612
+ ] = vae_state_dict.pop(f'encoder.down.{i}.downsample.conv.bias')
613
+
614
+ paths = renew_vae_resnet_paths(resnets)
615
+ meta_path = {
616
+ 'old': f'down.{i}.block',
617
+ 'new': f'down_blocks.{i}.resnets',
618
+ }
619
+ assign_to_checkpoint(
620
+ paths,
621
+ new_checkpoint,
622
+ vae_state_dict,
623
+ additional_replacements=[meta_path],
624
+ config=config,
625
+ )
626
+
627
+ mid_resnets = [key for key in vae_state_dict if 'encoder.mid.block' in key]
628
+ num_mid_res_blocks = 2
629
+ for i in range(1, num_mid_res_blocks + 1):
630
+ resnets = [
631
+ key for key in mid_resnets if f'encoder.mid.block_{i}' in key
632
+ ]
633
+
634
+ paths = renew_vae_resnet_paths(resnets)
635
+ meta_path = {
636
+ 'old': f'mid.block_{i}',
637
+ 'new': f'mid_block.resnets.{i - 1}',
638
+ }
639
+ assign_to_checkpoint(
640
+ paths,
641
+ new_checkpoint,
642
+ vae_state_dict,
643
+ additional_replacements=[meta_path],
644
+ config=config,
645
+ )
646
+
647
+ mid_attentions = [
648
+ key for key in vae_state_dict if 'encoder.mid.attn' in key
649
+ ]
650
+ paths = renew_vae_attention_paths(mid_attentions)
651
+ meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
652
+ assign_to_checkpoint(
653
+ paths,
654
+ new_checkpoint,
655
+ vae_state_dict,
656
+ additional_replacements=[meta_path],
657
+ config=config,
658
+ )
659
+ conv_attn_to_linear(new_checkpoint)
660
+
661
+ for i in range(num_up_blocks):
662
+ block_id = num_up_blocks - 1 - i
663
+ resnets = [
664
+ key
665
+ for key in up_blocks[block_id]
666
+ if f'up.{block_id}' in key and f'up.{block_id}.upsample' not in key
667
+ ]
668
+
669
+ if f'decoder.up.{block_id}.upsample.conv.weight' in vae_state_dict:
670
+ new_checkpoint[
671
+ f'decoder.up_blocks.{i}.upsamplers.0.conv.weight'
672
+ ] = vae_state_dict[f'decoder.up.{block_id}.upsample.conv.weight']
673
+ new_checkpoint[
674
+ f'decoder.up_blocks.{i}.upsamplers.0.conv.bias'
675
+ ] = vae_state_dict[f'decoder.up.{block_id}.upsample.conv.bias']
676
+
677
+ paths = renew_vae_resnet_paths(resnets)
678
+ meta_path = {
679
+ 'old': f'up.{block_id}.block',
680
+ 'new': f'up_blocks.{i}.resnets',
681
+ }
682
+ assign_to_checkpoint(
683
+ paths,
684
+ new_checkpoint,
685
+ vae_state_dict,
686
+ additional_replacements=[meta_path],
687
+ config=config,
688
+ )
689
+
690
+ mid_resnets = [key for key in vae_state_dict if 'decoder.mid.block' in key]
691
+ num_mid_res_blocks = 2
692
+ for i in range(1, num_mid_res_blocks + 1):
693
+ resnets = [
694
+ key for key in mid_resnets if f'decoder.mid.block_{i}' in key
695
+ ]
696
+
697
+ paths = renew_vae_resnet_paths(resnets)
698
+ meta_path = {
699
+ 'old': f'mid.block_{i}',
700
+ 'new': f'mid_block.resnets.{i - 1}',
701
+ }
702
+ assign_to_checkpoint(
703
+ paths,
704
+ new_checkpoint,
705
+ vae_state_dict,
706
+ additional_replacements=[meta_path],
707
+ config=config,
708
+ )
709
+
710
+ mid_attentions = [
711
+ key for key in vae_state_dict if 'decoder.mid.attn' in key
712
+ ]
713
+ paths = renew_vae_attention_paths(mid_attentions)
714
+ meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
715
+ assign_to_checkpoint(
716
+ paths,
717
+ new_checkpoint,
718
+ vae_state_dict,
719
+ additional_replacements=[meta_path],
720
+ config=config,
721
+ )
722
+ conv_attn_to_linear(new_checkpoint)
723
+ return new_checkpoint
724
+
725
+
726
+ def create_unet_diffusers_config(v2):
727
+ """
728
+ Creates a config for the diffusers based on the config of the LDM model.
729
+ """
730
+ # unet_params = original_config.model.params.unet_config.params
731
+
732
+ block_out_channels = [
733
+ UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
734
+ ]
735
+
736
+ down_block_types = []
737
+ resolution = 1
738
+ for i in range(len(block_out_channels)):
739
+ block_type = (
740
+ 'CrossAttnDownBlock2D'
741
+ if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
742
+ else 'DownBlock2D'
743
+ )
744
+ down_block_types.append(block_type)
745
+ if i != len(block_out_channels) - 1:
746
+ resolution *= 2
747
+
748
+ up_block_types = []
749
+ for i in range(len(block_out_channels)):
750
+ block_type = (
751
+ 'CrossAttnUpBlock2D'
752
+ if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
753
+ else 'UpBlock2D'
754
+ )
755
+ up_block_types.append(block_type)
756
+ resolution //= 2
757
+
758
+ config = dict(
759
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
760
+ in_channels=UNET_PARAMS_IN_CHANNELS,
761
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
762
+ down_block_types=tuple(down_block_types),
763
+ up_block_types=tuple(up_block_types),
764
+ block_out_channels=tuple(block_out_channels),
765
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
766
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM
767
+ if not v2
768
+ else V2_UNET_PARAMS_CONTEXT_DIM,
769
+ attention_head_dim=UNET_PARAMS_NUM_HEADS
770
+ if not v2
771
+ else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
772
+ )
773
+
774
+ return config
775
+
776
+
777
+ def create_vae_diffusers_config():
778
+ """
779
+ Creates a config for the diffusers based on the config of the LDM model.
780
+ """
781
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
782
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
783
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
784
+ down_block_types = ['DownEncoderBlock2D'] * len(block_out_channels)
785
+ up_block_types = ['UpDecoderBlock2D'] * len(block_out_channels)
786
+
787
+ config = dict(
788
+ sample_size=VAE_PARAMS_RESOLUTION,
789
+ in_channels=VAE_PARAMS_IN_CHANNELS,
790
+ out_channels=VAE_PARAMS_OUT_CH,
791
+ down_block_types=tuple(down_block_types),
792
+ up_block_types=tuple(up_block_types),
793
+ block_out_channels=tuple(block_out_channels),
794
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
795
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
796
+ )
797
+ return config
798
+
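The two config helpers above produce keyword arguments that can be passed directly to the Diffusers classes, exactly as the loaders further down do; a short sketch:

from diffusers import AutoencoderKL, UNet2DConditionModel

unet = UNet2DConditionModel(**create_unet_diffusers_config(v2=False))  # SD 1.x U-Net layout
vae = AutoencoderKL(**create_vae_diffusers_config())                   # matching VAE layout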
799
+
800
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
801
+ keys = list(checkpoint.keys())
802
+ text_model_dict = {}
803
+ for key in keys:
804
+ if key.startswith('cond_stage_model.transformer'):
805
+ text_model_dict[
806
+ key[len('cond_stage_model.transformer.') :]
807
+ ] = checkpoint[key]
808
+ return text_model_dict
809
+
810
+
811
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
812
+ # the key layout is annoyingly different!
813
+ def convert_key(key):
814
+ if not key.startswith('cond_stage_model'):
815
+ return None
816
+
817
+ # common conversion
818
+ key = key.replace(
819
+ 'cond_stage_model.model.transformer.', 'text_model.encoder.'
820
+ )
821
+ key = key.replace('cond_stage_model.model.', 'text_model.')
822
+
823
+ if 'resblocks' in key:
824
+ # resblocks conversion
825
+ key = key.replace('.resblocks.', '.layers.')
826
+ if '.ln_' in key:
827
+ key = key.replace('.ln_', '.layer_norm')
828
+ elif '.mlp.' in key:
829
+ key = key.replace('.c_fc.', '.fc1.')
830
+ key = key.replace('.c_proj.', '.fc2.')
831
+ elif '.attn.out_proj' in key:
832
+ key = key.replace('.attn.out_proj.', '.self_attn.out_proj.')
833
+ elif '.attn.in_proj' in key:
834
+ key = None # special case, handled later
835
+ else:
836
+ raise ValueError(f'unexpected key in SD: {key}')
837
+ elif '.positional_embedding' in key:
838
+ key = key.replace(
839
+ '.positional_embedding',
840
+ '.embeddings.position_embedding.weight',
841
+ )
842
+ elif '.text_projection' in key:
843
+ key = None # not used???
844
+ elif '.logit_scale' in key:
845
+ key = None # not used???
846
+ elif '.token_embedding' in key:
847
+ key = key.replace(
848
+ '.token_embedding.weight', '.embeddings.token_embedding.weight'
849
+ )
850
+ elif '.ln_final' in key:
851
+ key = key.replace('.ln_final', '.final_layer_norm')
852
+ return key
853
+
854
+ keys = list(checkpoint.keys())
855
+ new_sd = {}
856
+ for key in keys:
857
+ # remove resblocks 23
858
+ if '.resblocks.23.' in key:
859
+ continue
860
+ new_key = convert_key(key)
861
+ if new_key is None:
862
+ continue
863
+ new_sd[new_key] = checkpoint[key]
864
+
865
+ # convert the attention weights
866
+ for key in keys:
867
+ if '.resblocks.23.' in key:
868
+ continue
869
+ if '.resblocks' in key and '.attn.in_proj_' in key:
870
+ # split into three (q, k, v)
871
+ values = torch.chunk(checkpoint[key], 3)
872
+
873
+ key_suffix = '.weight' if 'weight' in key else '.bias'
874
+ key_pfx = key.replace(
875
+ 'cond_stage_model.model.transformer.resblocks.',
876
+ 'text_model.encoder.layers.',
877
+ )
878
+ key_pfx = key_pfx.replace('_weight', '')
879
+ key_pfx = key_pfx.replace('_bias', '')
880
+ key_pfx = key_pfx.replace('.attn.in_proj', '.self_attn.')
881
+ new_sd[key_pfx + 'q_proj' + key_suffix] = values[0]
882
+ new_sd[key_pfx + 'k_proj' + key_suffix] = values[1]
883
+ new_sd[key_pfx + 'v_proj' + key_suffix] = values[2]
884
+
885
+ # add position_ids
886
+ new_sd['text_model.embeddings.position_ids'] = torch.Tensor(
887
+ [list(range(max_length))]
888
+ ).to(torch.int64)
889
+ return new_sd
890
+
891
+
892
+ # endregion
893
+
894
+
895
+ # region Diffusers -> StableDiffusion conversion code
896
+ # copied and modified from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
897
+
898
+
899
+ def conv_transformer_to_linear(checkpoint):
900
+ keys = list(checkpoint.keys())
901
+ tf_keys = ['proj_in.weight', 'proj_out.weight']
902
+ for key in keys:
903
+ if '.'.join(key.split('.')[-2:]) in tf_keys:
904
+ if checkpoint[key].ndim > 2:
905
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
906
+
907
+
908
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
909
+ unet_conversion_map = [
910
+ # (stable-diffusion, HF Diffusers)
911
+ ('time_embed.0.weight', 'time_embedding.linear_1.weight'),
912
+ ('time_embed.0.bias', 'time_embedding.linear_1.bias'),
913
+ ('time_embed.2.weight', 'time_embedding.linear_2.weight'),
914
+ ('time_embed.2.bias', 'time_embedding.linear_2.bias'),
915
+ ('input_blocks.0.0.weight', 'conv_in.weight'),
916
+ ('input_blocks.0.0.bias', 'conv_in.bias'),
917
+ ('out.0.weight', 'conv_norm_out.weight'),
918
+ ('out.0.bias', 'conv_norm_out.bias'),
919
+ ('out.2.weight', 'conv_out.weight'),
920
+ ('out.2.bias', 'conv_out.bias'),
921
+ ]
922
+
923
+ unet_conversion_map_resnet = [
924
+ # (stable-diffusion, HF Diffusers)
925
+ ('in_layers.0', 'norm1'),
926
+ ('in_layers.2', 'conv1'),
927
+ ('out_layers.0', 'norm2'),
928
+ ('out_layers.3', 'conv2'),
929
+ ('emb_layers.1', 'time_emb_proj'),
930
+ ('skip_connection', 'conv_shortcut'),
931
+ ]
932
+
933
+ unet_conversion_map_layer = []
934
+ for i in range(4):
935
+ # loop over downblocks/upblocks
936
+
937
+ for j in range(2):
938
+ # loop over resnets/attentions for downblocks
939
+ hf_down_res_prefix = f'down_blocks.{i}.resnets.{j}.'
940
+ sd_down_res_prefix = f'input_blocks.{3*i + j + 1}.0.'
941
+ unet_conversion_map_layer.append(
942
+ (sd_down_res_prefix, hf_down_res_prefix)
943
+ )
944
+
945
+ if i < 3:
946
+ # no attention layers in down_blocks.3
947
+ hf_down_atn_prefix = f'down_blocks.{i}.attentions.{j}.'
948
+ sd_down_atn_prefix = f'input_blocks.{3*i + j + 1}.1.'
949
+ unet_conversion_map_layer.append(
950
+ (sd_down_atn_prefix, hf_down_atn_prefix)
951
+ )
952
+
953
+ for j in range(3):
954
+ # loop over resnets/attentions for upblocks
955
+ hf_up_res_prefix = f'up_blocks.{i}.resnets.{j}.'
956
+ sd_up_res_prefix = f'output_blocks.{3*i + j}.0.'
957
+ unet_conversion_map_layer.append(
958
+ (sd_up_res_prefix, hf_up_res_prefix)
959
+ )
960
+
961
+ if i > 0:
962
+ # no attention layers in up_blocks.0
963
+ hf_up_atn_prefix = f'up_blocks.{i}.attentions.{j}.'
964
+ sd_up_atn_prefix = f'output_blocks.{3*i + j}.1.'
965
+ unet_conversion_map_layer.append(
966
+ (sd_up_atn_prefix, hf_up_atn_prefix)
967
+ )
968
+
969
+ if i < 3:
970
+ # no downsample in down_blocks.3
971
+ hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.conv.'
972
+ sd_downsample_prefix = f'input_blocks.{3*(i+1)}.0.op.'
973
+ unet_conversion_map_layer.append(
974
+ (sd_downsample_prefix, hf_downsample_prefix)
975
+ )
976
+
977
+ # no upsample in up_blocks.3
978
+ hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.'
979
+ sd_upsample_prefix = (
980
+ f'output_blocks.{3*i + 2}.{1 if i == 0 else 2}.'
981
+ )
982
+ unet_conversion_map_layer.append(
983
+ (sd_upsample_prefix, hf_upsample_prefix)
984
+ )
985
+
986
+ hf_mid_atn_prefix = 'mid_block.attentions.0.'
987
+ sd_mid_atn_prefix = 'middle_block.1.'
988
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
989
+
990
+ for j in range(2):
991
+ hf_mid_res_prefix = f'mid_block.resnets.{j}.'
992
+ sd_mid_res_prefix = f'middle_block.{2*j}.'
993
+ unet_conversion_map_layer.append(
994
+ (sd_mid_res_prefix, hf_mid_res_prefix)
995
+ )
996
+
997
+ # buyer beware: this is a *brittle* function,
998
+ # and correct output requires that all of these pieces interact in
999
+ # the exact order in which I have arranged them.
1000
+ mapping = {k: k for k in unet_state_dict.keys()}
1001
+ for sd_name, hf_name in unet_conversion_map:
1002
+ mapping[hf_name] = sd_name
1003
+ for k, v in mapping.items():
1004
+ if 'resnets' in k:
1005
+ for sd_part, hf_part in unet_conversion_map_resnet:
1006
+ v = v.replace(hf_part, sd_part)
1007
+ mapping[k] = v
1008
+ for k, v in mapping.items():
1009
+ for sd_part, hf_part in unet_conversion_map_layer:
1010
+ v = v.replace(hf_part, sd_part)
1011
+ mapping[k] = v
1012
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
1013
+
1014
+ if v2:
1015
+ conv_transformer_to_linear(new_state_dict)
1016
+
1017
+ return new_state_dict
1018
+
1019
+
1020
+ # ================#
1021
+ # VAE Conversion #
1022
+ # ================#
1023
+
1024
+
1025
+ def reshape_weight_for_sd(w):
1026
+ # convert HF linear weights to SD conv2d weights
1027
+ return w.reshape(*w.shape, 1, 1)
1028
+
1029
+
1030
+ def convert_vae_state_dict(vae_state_dict):
1031
+ vae_conversion_map = [
1032
+ # (stable-diffusion, HF Diffusers)
1033
+ ('nin_shortcut', 'conv_shortcut'),
1034
+ ('norm_out', 'conv_norm_out'),
1035
+ ('mid.attn_1.', 'mid_block.attentions.0.'),
1036
+ ]
1037
+
1038
+ for i in range(4):
1039
+ # down_blocks have two resnets
1040
+ for j in range(2):
1041
+ hf_down_prefix = f'encoder.down_blocks.{i}.resnets.{j}.'
1042
+ sd_down_prefix = f'encoder.down.{i}.block.{j}.'
1043
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
1044
+
1045
+ if i < 3:
1046
+ hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.'
1047
+ sd_downsample_prefix = f'down.{i}.downsample.'
1048
+ vae_conversion_map.append(
1049
+ (sd_downsample_prefix, hf_downsample_prefix)
1050
+ )
1051
+
1052
+ hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.'
1053
+ sd_upsample_prefix = f'up.{3-i}.upsample.'
1054
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
1055
+
1056
+ # up_blocks have three resnets
1057
+ # also, up blocks in hf are numbered in reverse from sd
1058
+ for j in range(3):
1059
+ hf_up_prefix = f'decoder.up_blocks.{i}.resnets.{j}.'
1060
+ sd_up_prefix = f'decoder.up.{3-i}.block.{j}.'
1061
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
1062
+
1063
+ # this part accounts for mid blocks in both the encoder and the decoder
1064
+ for i in range(2):
1065
+ hf_mid_res_prefix = f'mid_block.resnets.{i}.'
1066
+ sd_mid_res_prefix = f'mid.block_{i+1}.'
1067
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
1068
+
1069
+ vae_conversion_map_attn = [
1070
+ # (stable-diffusion, HF Diffusers)
1071
+ ('norm.', 'group_norm.'),
1072
+ ('q.', 'query.'),
1073
+ ('k.', 'key.'),
1074
+ ('v.', 'value.'),
1075
+ ('proj_out.', 'proj_attn.'),
1076
+ ]
1077
+
1078
+ mapping = {k: k for k in vae_state_dict.keys()}
1079
+ for k, v in mapping.items():
1080
+ for sd_part, hf_part in vae_conversion_map:
1081
+ v = v.replace(hf_part, sd_part)
1082
+ mapping[k] = v
1083
+ for k, v in mapping.items():
1084
+ if 'attentions' in k:
1085
+ for sd_part, hf_part in vae_conversion_map_attn:
1086
+ v = v.replace(hf_part, sd_part)
1087
+ mapping[k] = v
1088
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
1089
+ weights_to_convert = ['q', 'k', 'v', 'proj_out']
1090
+ for k, v in new_state_dict.items():
1091
+ for weight_name in weights_to_convert:
1092
+ if f'mid.attn_1.{weight_name}.weight' in k:
1093
+ # print(f"Reshaping {k} for SD format")
1094
+ new_state_dict[k] = reshape_weight_for_sd(v)
1095
+
1096
+ return new_state_dict
1097
+
1098
+
1099
+ # endregion
1100
+
1101
+ # region custom model loading/saving helpers
1102
+
1103
+
1104
+ def is_safetensors(path):
1105
+ return os.path.splitext(path)[1].lower() == '.safetensors'
1106
+
1107
+
1108
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
1109
+ # handle models that store the text encoder in a different layout (no 'text_model' prefix)
1110
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
1111
+ (
1112
+ 'cond_stage_model.transformer.embeddings.',
1113
+ 'cond_stage_model.transformer.text_model.embeddings.',
1114
+ ),
1115
+ (
1116
+ 'cond_stage_model.transformer.encoder.',
1117
+ 'cond_stage_model.transformer.text_model.encoder.',
1118
+ ),
1119
+ (
1120
+ 'cond_stage_model.transformer.final_layer_norm.',
1121
+ 'cond_stage_model.transformer.text_model.final_layer_norm.',
1122
+ ),
1123
+ ]
1124
+
1125
+ if is_safetensors(ckpt_path):
1126
+ checkpoint = None
1127
+ state_dict = load_file(ckpt_path, 'cpu')
1128
+ else:
1129
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
1130
+ if 'state_dict' in checkpoint:
1131
+ state_dict = checkpoint['state_dict']
1132
+ else:
1133
+ state_dict = checkpoint
1134
+ checkpoint = None
1135
+
1136
+ key_reps = []
1137
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
1138
+ for key in state_dict.keys():
1139
+ if key.startswith(rep_from):
1140
+ new_key = rep_to + key[len(rep_from) :]
1141
+ key_reps.append((key, new_key))
1142
+
1143
+ for key, new_key in key_reps:
1144
+ state_dict[new_key] = state_dict[key]
1145
+ del state_dict[key]
1146
+
1147
+ return checkpoint, state_dict
1148
+
1149
+
1150
+ # TODO the dtype handling looks suspicious and should be checked; building the text_encoder in the requested dtype is unverified
1151
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
1152
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1153
+ if dtype is not None:
1154
+ for k, v in state_dict.items():
1155
+ if type(v) is torch.Tensor:
1156
+ state_dict[k] = v.to(dtype)
1157
+
1158
+ # Convert the UNet2DConditionModel model.
1159
+ unet_config = create_unet_diffusers_config(v2)
1160
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
1161
+ v2, state_dict, unet_config
1162
+ )
1163
+
1164
+ unet = UNet2DConditionModel(**unet_config)
1165
+ info = unet.load_state_dict(converted_unet_checkpoint)
1166
+ print('loading u-net:', info)
1167
+
1168
+ # Convert the VAE model.
1169
+ vae_config = create_vae_diffusers_config()
1170
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(
1171
+ state_dict, vae_config
1172
+ )
1173
+
1174
+ vae = AutoencoderKL(**vae_config)
1175
+ info = vae.load_state_dict(converted_vae_checkpoint)
1176
+ print('loading vae:', info)
1177
+
1178
+ # convert text_model
1179
+ if v2:
1180
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(
1181
+ state_dict, 77
1182
+ )
1183
+ cfg = CLIPTextConfig(
1184
+ vocab_size=49408,
1185
+ hidden_size=1024,
1186
+ intermediate_size=4096,
1187
+ num_hidden_layers=23,
1188
+ num_attention_heads=16,
1189
+ max_position_embeddings=77,
1190
+ hidden_act='gelu',
1191
+ layer_norm_eps=1e-05,
1192
+ dropout=0.0,
1193
+ attention_dropout=0.0,
1194
+ initializer_range=0.02,
1195
+ initializer_factor=1.0,
1196
+ pad_token_id=1,
1197
+ bos_token_id=0,
1198
+ eos_token_id=2,
1199
+ model_type='clip_text_model',
1200
+ projection_dim=512,
1201
+ torch_dtype='float32',
1202
+ transformers_version='4.25.0.dev0',
1203
+ )
1204
+ text_model = CLIPTextModel._from_config(cfg)
1205
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
1206
+ else:
1207
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(
1208
+ state_dict
1209
+ )
1210
+ text_model = CLIPTextModel.from_pretrained(
1211
+ 'openai/clip-vit-large-patch14'
1212
+ )
1213
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
1214
+ print('loading text encoder:', info)
1215
+
1216
+ return text_model, vae, unet
1217
+
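A short usage sketch for the loader above; the checkpoint path is a placeholder:

import torch

text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(
    v2=False, ckpt_path='model.safetensors', dtype=torch.float16
)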
1218
+
1219
+ def convert_text_encoder_state_dict_to_sd_v2(
1220
+ checkpoint, make_dummy_weights=False
1221
+ ):
1222
+ def convert_key(key):
1223
+ # remove position_ids
1224
+ if '.position_ids' in key:
1225
+ return None
1226
+
1227
+ # common
1228
+ key = key.replace('text_model.encoder.', 'transformer.')
1229
+ key = key.replace('text_model.', '')
1230
+ if 'layers' in key:
1231
+ # resblocks conversion
1232
+ key = key.replace('.layers.', '.resblocks.')
1233
+ if '.layer_norm' in key:
1234
+ key = key.replace('.layer_norm', '.ln_')
1235
+ elif '.mlp.' in key:
1236
+ key = key.replace('.fc1.', '.c_fc.')
1237
+ key = key.replace('.fc2.', '.c_proj.')
1238
+ elif '.self_attn.out_proj' in key:
1239
+ key = key.replace('.self_attn.out_proj.', '.attn.out_proj.')
1240
+ elif '.self_attn.' in key:
1241
+ key = None # special case, handled later
1242
+ else:
1243
+ raise ValueError(f'unexpected key in DiffUsers model: {key}')
1244
+ elif '.position_embedding' in key:
1245
+ key = key.replace(
1246
+ 'embeddings.position_embedding.weight', 'positional_embedding'
1247
+ )
1248
+ elif '.token_embedding' in key:
1249
+ key = key.replace(
1250
+ 'embeddings.token_embedding.weight', 'token_embedding.weight'
1251
+ )
1252
+ elif 'final_layer_norm' in key:
1253
+ key = key.replace('final_layer_norm', 'ln_final')
1254
+ return key
1255
+
1256
+ keys = list(checkpoint.keys())
1257
+ new_sd = {}
1258
+ for key in keys:
1259
+ new_key = convert_key(key)
1260
+ if new_key is None:
1261
+ continue
1262
+ new_sd[new_key] = checkpoint[key]
1263
+
1264
+ # convert the attention weights
1265
+ for key in keys:
1266
+ if 'layers' in key and 'q_proj' in key:
1267
+ # concatenate the three (q, k, v)
1268
+ key_q = key
1269
+ key_k = key.replace('q_proj', 'k_proj')
1270
+ key_v = key.replace('q_proj', 'v_proj')
1271
+
1272
+ value_q = checkpoint[key_q]
1273
+ value_k = checkpoint[key_k]
1274
+ value_v = checkpoint[key_v]
1275
+ value = torch.cat([value_q, value_k, value_v])
1276
+
1277
+ new_key = key.replace(
1278
+ 'text_model.encoder.layers.', 'transformer.resblocks.'
1279
+ )
1280
+ new_key = new_key.replace('.self_attn.q_proj.', '.attn.in_proj_')
1281
+ new_sd[new_key] = value
1282
+
1283
+ # whether to fabricate dummy weights such as the final layer
1284
+ if make_dummy_weights:
1285
+ print(
1286
+ 'make dummy weights for resblock.23, text_projection and logit scale.'
1287
+ )
1288
+ keys = list(new_sd.keys())
1289
+ for key in keys:
1290
+ if key.startswith('transformer.resblocks.22.'):
1291
+ new_sd[key.replace('.22.', '.23.')] = new_sd[
1292
+ key
1293
+ ].clone() # must be cloned, otherwise saving with safetensors fails
1294
+
1295
+ # create weights that are not present in the Diffusers model
1296
+ new_sd['text_projection'] = torch.ones(
1297
+ (1024, 1024),
1298
+ dtype=new_sd[keys[0]].dtype,
1299
+ device=new_sd[keys[0]].device,
1300
+ )
1301
+ new_sd['logit_scale'] = torch.tensor(1)
1302
+
1303
+ return new_sd
1304
+
1305
+
1306
+ def save_stable_diffusion_checkpoint(
1307
+ v2,
1308
+ output_file,
1309
+ text_encoder,
1310
+ unet,
1311
+ ckpt_path,
1312
+ epochs,
1313
+ steps,
1314
+ save_dtype=None,
1315
+ vae=None,
1316
+ ):
1317
+ if ckpt_path is not None:
1318
+ # read epoch/step from the checkpoint; also reload it (including the VAE), e.g. when the VAE is not in memory
1319
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(
1320
+ ckpt_path
1321
+ )
1322
+ if checkpoint is None: # safetensors file, or a ckpt that only contains a state_dict
1323
+ checkpoint = {}
1324
+ strict = False
1325
+ else:
1326
+ strict = True
1327
+ if 'state_dict' in state_dict:
1328
+ del state_dict['state_dict']
1329
+ else:
1330
+ # create a new checkpoint from scratch
1331
+ assert (
1332
+ vae is not None
1333
+ ), 'VAE is required to save a checkpoint without a given checkpoint'
1334
+ checkpoint = {}
1335
+ state_dict = {}
1336
+ strict = False
1337
+
1338
+ def update_sd(prefix, sd):
1339
+ for k, v in sd.items():
1340
+ key = prefix + k
1341
+ assert (
1342
+ not strict or key in state_dict
1343
+ ), f'Illegal key in save SD: {key}'
1344
+ if save_dtype is not None:
1345
+ v = v.detach().clone().to('cpu').to(save_dtype)
1346
+ state_dict[key] = v
1347
+
1348
+ # Convert the UNet model
1349
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1350
+ update_sd('model.diffusion_model.', unet_state_dict)
1351
+
1352
+ # Convert the text encoder model
1353
+ if v2:
1354
+ make_dummy = (
1355
+ ckpt_path is None
1356
+ ) # when there is no source checkpoint, insert dummy weights, e.g. build the last layer by copying the previous one
1357
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(
1358
+ text_encoder.state_dict(), make_dummy
1359
+ )
1360
+ update_sd('cond_stage_model.model.', text_enc_dict)
1361
+ else:
1362
+ text_enc_dict = text_encoder.state_dict()
1363
+ update_sd('cond_stage_model.transformer.', text_enc_dict)
1364
+
1365
+ # Convert the VAE
1366
+ if vae is not None:
1367
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1368
+ update_sd('first_stage_model.', vae_dict)
1369
+
1370
+ # Put together new checkpoint
1371
+ key_count = len(state_dict.keys())
1372
+ new_ckpt = {'state_dict': state_dict}
1373
+
1374
+ if 'epoch' in checkpoint:
1375
+ epochs += checkpoint['epoch']
1376
+ if 'global_step' in checkpoint:
1377
+ steps += checkpoint['global_step']
1378
+
1379
+ new_ckpt['epoch'] = epochs
1380
+ new_ckpt['global_step'] = steps
1381
+
1382
+ if is_safetensors(output_file):
1383
+ # TODO consider removing non-Tensor values from the dict
1384
+ save_file(state_dict, output_file)
1385
+ else:
1386
+ torch.save(new_ckpt, output_file)
1387
+
1388
+ return key_count
1389
+
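And the matching save direction, reusing the components returned by the loader (a sketch; the file names are placeholders):

import torch

key_count = save_stable_diffusion_checkpoint(
    v2=False,
    output_file='model_out.safetensors',
    text_encoder=text_encoder,
    unet=unet,
    ckpt_path='model.safetensors',  # source checkpoint to start from (enables strict key checking)
    epochs=1,
    steps=1000,
    save_dtype=torch.float16,
    vae=vae,
)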
1390
+
1391
+ def save_diffusers_checkpoint(
1392
+ v2,
1393
+ output_dir,
1394
+ text_encoder,
1395
+ unet,
1396
+ pretrained_model_name_or_path,
1397
+ vae=None,
1398
+ use_safetensors=False,
1399
+ ):
1400
+ if pretrained_model_name_or_path is None:
1401
+ # load default settings for v1/v2
1402
+ if v2:
1403
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1404
+ else:
1405
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1406
+
1407
+ scheduler = DDIMScheduler.from_pretrained(
1408
+ pretrained_model_name_or_path, subfolder='scheduler'
1409
+ )
1410
+ tokenizer = CLIPTokenizer.from_pretrained(
1411
+ pretrained_model_name_or_path, subfolder='tokenizer'
1412
+ )
1413
+ if vae is None:
1414
+ vae = AutoencoderKL.from_pretrained(
1415
+ pretrained_model_name_or_path, subfolder='vae'
1416
+ )
1417
+
1418
+ pipeline = StableDiffusionPipeline(
1419
+ unet=unet,
1420
+ text_encoder=text_encoder,
1421
+ vae=vae,
1422
+ scheduler=scheduler,
1423
+ tokenizer=tokenizer,
1424
+ safety_checker=None,
1425
+ feature_extractor=None,
1426
+ requires_safety_checker=None,
1427
+ )
1428
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1429
+
1430
+
1431
+ VAE_PREFIX = 'first_stage_model.'
1432
+
1433
+
1434
+ def load_vae(vae_id, dtype):
1435
+ print(f'load VAE: {vae_id}')
1436
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1437
+ # Diffusers local/remote
1438
+ try:
1439
+ vae = AutoencoderKL.from_pretrained(
1440
+ vae_id, subfolder=None, torch_dtype=dtype
1441
+ )
1442
+ except EnvironmentError as e:
1443
+ print(f'exception occurs in loading vae: {e}')
1444
+ print("retry with subfolder='vae'")
1445
+ vae = AutoencoderKL.from_pretrained(
1446
+ vae_id, subfolder='vae', torch_dtype=dtype
1447
+ )
1448
+ return vae
1449
+
1450
+ # local
1451
+ vae_config = create_vae_diffusers_config()
1452
+
1453
+ if vae_id.endswith('.bin'):
1454
+ # SD 1.5 VAE on Huggingface
1455
+ vae_sd = torch.load(vae_id, map_location='cpu')
1456
+ converted_vae_checkpoint = vae_sd
1457
+ else:
1458
+ # StableDiffusion
1459
+ vae_model = torch.load(vae_id, map_location='cpu')
1460
+ vae_sd = vae_model['state_dict']
1461
+
1462
+ # vae only or full model
1463
+ full_model = False
1464
+ for vae_key in vae_sd:
1465
+ if vae_key.startswith(VAE_PREFIX):
1466
+ full_model = True
1467
+ break
1468
+ if not full_model:
1469
+ sd = {}
1470
+ for key, value in vae_sd.items():
1471
+ sd[VAE_PREFIX + key] = value
1472
+ vae_sd = sd
1473
+ del sd
1474
+
1475
+ # Convert the VAE model.
1476
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(
1477
+ vae_sd, vae_config
1478
+ )
1479
+
1480
+ vae = AutoencoderKL(**vae_config)
1481
+ vae.load_state_dict(converted_vae_checkpoint)
1482
+ return vae
1483
+
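load_vae accepts either a Diffusers folder/repo id or a single .ckpt/.bin file; for example (a sketch, the repo id is only an illustration):

import torch

vae = load_vae('stabilityai/sd-vae-ft-mse', dtype=torch.float16)  # Diffusers-format VAE from the Hub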
1484
+
1485
+ def get_epoch_ckpt_name(use_safetensors, epoch):
1486
+ return f'epoch-{epoch:06d}' + (
1487
+ '.safetensors' if use_safetensors else '.ckpt'
1488
+ )
1489
+
1490
+
1491
+ def get_last_ckpt_name(use_safetensors):
1492
+ return 'last' + ('.safetensors' if use_safetensors else '.ckpt')
1493
+
1494
+
1495
+ # endregion
1496
+
1497
+
1498
+ def make_bucket_resolutions(
1499
+ max_reso, min_size=256, max_size=1024, divisible=64
1500
+ ):
1501
+ max_width, max_height = max_reso
1502
+ max_area = (max_width // divisible) * (max_height // divisible)
1503
+
1504
+ resos = set()
1505
+
1506
+ size = int(math.sqrt(max_area)) * divisible
1507
+ resos.add((size, size))
1508
+
1509
+ size = min_size
1510
+ while size <= max_size:
1511
+ width = size
1512
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
1513
+ resos.add((width, height))
1514
+ resos.add((height, width))
1515
+
1516
+ # # make additional resos
1517
+ # if width >= height and width - divisible >= min_size:
1518
+ # resos.add((width - divisible, height))
1519
+ # resos.add((height, width - divisible))
1520
+ # if height >= width and height - divisible >= min_size:
1521
+ # resos.add((width, height - divisible))
1522
+ # resos.add((height - divisible, width))
1523
+
1524
+ size += divisible
1525
+
1526
+ resos = list(resos)
1527
+ resos.sort()
1528
+
1529
+ aspect_ratios = [w / h for w, h in resos]
1530
+ return resos, aspect_ratios
1531
+
1532
+
1533
+ if __name__ == '__main__':
1534
+ resos, aspect_ratios = make_bucket_resolutions((512, 768))
1535
+ print(len(resos))
1536
+ print(resos)
1537
+ print(aspect_ratios)
1538
+
1539
+ ars = set()
1540
+ for ar in aspect_ratios:
1541
+ if ar in ars:
1542
+ print('error! duplicate ar:', ar)
1543
+ ars.add(ar)
StableTuner_RunPod_Fix/trainer.py ADDED
@@ -0,0 +1,1750 @@
1
+ """
2
+ Copyright 2022 HuggingFace, ShivamShrirao
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+ import keyboard
17
+ import gradio as gr
18
+ import argparse
19
+ import random
20
+ import hashlib
21
+ import itertools
22
+ import json
23
+ import math
24
+ import os
25
+ import copy
26
+ from contextlib import nullcontext
27
+ from pathlib import Path
28
+ import shutil
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+ import numpy as np
33
+ from accelerate import Accelerator
34
+ from accelerate.logging import get_logger
35
+ from accelerate.utils import set_seed
36
+ from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel,DiffusionPipeline, DPMSolverMultistepScheduler,EulerDiscreteScheduler
37
+ from diffusers.optimization import get_scheduler
38
+ from torchvision.transforms import functional
39
+ from tqdm.auto import tqdm
40
+ from transformers import CLIPTextModel, CLIPTokenizer
41
+ from typing import Dict, List, Generator, Tuple
42
+ from PIL import Image, ImageFile
43
+ from diffusers.utils.import_utils import is_xformers_available
44
+ from trainer_util import *
45
+ from dataloaders_util import *
46
+ from discriminator import Discriminator2D
47
+ from lion_pytorch import Lion
48
+ logger = get_logger(__name__)
49
+ def parse_args():
50
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
51
+ parser.add_argument(
52
+ "--revision",
53
+ type=str,
54
+ default=None,
55
+ required=False,
56
+ help="Revision of pretrained model identifier from huggingface.co/models.",
57
+ )
58
+
59
+ parser.add_argument(
60
+ "--attention",
61
+ type=str,
62
+ choices=["xformers", "flash_attention"],
63
+ default="xformers",
64
+ help="Type of attention to use."
65
+ )
66
+ parser.add_argument(
67
+ "--model_variant",
68
+ type=str,
69
+ default='base',
70
+ required=False,
71
+ help="Train Base/Inpaint/Depth2Img",
72
+ )
73
+ parser.add_argument(
74
+ "--aspect_mode",
75
+ type=str,
76
+ default='dynamic',
77
+ required=False,
78
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
79
+ )
80
+ parser.add_argument(
81
+ "--aspect_mode_action_preference",
82
+ type=str,
83
+ default='add',
84
+ required=False,
85
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
86
+ )
87
+ parser.add_argument('--use_lion',default=False,action="store_true", help='Use the new LION optimizer')
88
+ parser.add_argument('--use_ema',default=False,action="store_true", help='Use EMA for finetuning')
89
+ parser.add_argument('--clip_penultimate',default=False,action="store_true", help='Use penultimate CLIP layer for text embedding')
90
+ parser.add_argument("--conditional_dropout", type=float, default=None,required=False, help="Conditional dropout probability")
91
+ parser.add_argument('--disable_cudnn_benchmark', default=False, action="store_true")
92
+ parser.add_argument('--use_text_files_as_captions', default=False, action="store_true")
93
+
94
+ parser.add_argument(
95
+ "--sample_from_batch",
96
+ type=int,
97
+ default=0,
98
+ help=("Number of prompts to sample from the batch for inference"),
99
+ )
100
+ parser.add_argument(
101
+ "--flatten_sample_folder",
102
+ default=True,
103
+ action="store_true",
104
+ help="Will save samples in one folder instead of per-epoch",
105
+ )
106
+ parser.add_argument(
107
+ "--stop_text_encoder_training",
108
+ type=int,
109
+ default=999999999999999,
110
+ help=("The epoch at which the text_encoder is no longer trained"),
111
+ )
112
+ parser.add_argument(
113
+ "--use_bucketing",
114
+ default=False,
115
+ action="store_true",
116
+ help="Will save and generate samples before training",
117
+ )
118
+ parser.add_argument(
119
+ "--regenerate_latent_cache",
120
+ default=False,
121
+ action="store_true",
122
+ help="Will save and generate samples before training",
123
+ )
124
+ parser.add_argument(
125
+ "--sample_on_training_start",
126
+ default=False,
127
+ action="store_true",
128
+ help="Will save and generate samples before training",
129
+ )
130
+
131
+ parser.add_argument(
132
+ "--add_class_images_to_dataset",
133
+ default=False,
134
+ action="store_true",
135
+ help="will generate and add class images to the dataset without using prior reservation in training",
136
+ )
137
+ parser.add_argument(
138
+ "--auto_balance_concept_datasets",
139
+ default=False,
140
+ action="store_true",
141
+ help="will balance the number of images in each concept dataset to match the minimum number of images in any concept dataset",
142
+ )
143
+ parser.add_argument(
144
+ "--sample_aspect_ratios",
145
+ default=False,
146
+ action="store_true",
147
+ help="sample different aspect ratios for each image",
148
+ )
149
+ parser.add_argument(
150
+ "--dataset_repeats",
151
+ type=int,
152
+ default=1,
153
+ help="repeat the dataset this many times",
154
+ )
155
+ parser.add_argument(
156
+ "--save_every_n_epoch",
157
+ type=int,
158
+ default=1,
159
+ help="save on epoch finished",
160
+ )
161
+ parser.add_argument(
162
+ "--pretrained_model_name_or_path",
163
+ type=str,
164
+ default=None,
165
+ required=True,
166
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
167
+ )
168
+ parser.add_argument(
169
+ "--pretrained_vae_name_or_path",
170
+ type=str,
171
+ default=None,
172
+ help="Path to pretrained vae or vae identifier from huggingface.co/models.",
173
+ )
174
+ parser.add_argument(
175
+ "--tokenizer_name",
176
+ type=str,
177
+ default=None,
178
+ help="Pretrained tokenizer name or path if not the same as model_name",
179
+ )
180
+ parser.add_argument(
181
+ "--instance_data_dir",
182
+ type=str,
183
+ default=None,
184
+ help="A folder containing the training data of instance images.",
185
+ )
186
+ parser.add_argument(
187
+ "--class_data_dir",
188
+ type=str,
189
+ default=None,
190
+ help="A folder containing the training data of class images.",
191
+ )
192
+ parser.add_argument(
193
+ "--instance_prompt",
194
+ type=str,
195
+ default=None,
196
+ help="The prompt with identifier specifying the instance",
197
+ )
198
+ parser.add_argument(
199
+ "--class_prompt",
200
+ type=str,
201
+ default=None,
202
+ help="The prompt to specify images in the same class as provided instance images.",
203
+ )
204
+ parser.add_argument(
205
+ "--save_sample_prompt",
206
+ type=str,
207
+ default=None,
208
+ help="The prompt used to generate sample outputs to save.",
209
+ )
210
+ parser.add_argument(
211
+ "--n_save_sample",
212
+ type=int,
213
+ default=4,
214
+ help="The number of samples to save.",
215
+ )
216
+ parser.add_argument(
217
+ "--sample_height",
218
+ type=int,
219
+ default=512,
220
+ help="The number of samples to save.",
221
+ )
222
+ parser.add_argument(
223
+ "--sample_width",
224
+ type=int,
225
+ default=512,
226
+ help="The number of samples to save.",
227
+ )
228
+ parser.add_argument(
229
+ "--save_guidance_scale",
230
+ type=float,
231
+ default=7.5,
232
+ help="CFG for save sample.",
233
+ )
234
+ parser.add_argument(
235
+ "--save_infer_steps",
236
+ type=int,
237
+ default=30,
238
+ help="The number of inference steps for save sample.",
239
+ )
240
+ parser.add_argument(
241
+ "--with_prior_preservation",
242
+ default=False,
243
+ action="store_true",
244
+ help="Flag to add prior preservation loss.",
245
+ )
246
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
247
+ parser.add_argument(
248
+ "--with_offset_noise",
249
+ default=False,
250
+ action="store_true",
251
+ help="Flag to offset noise applied to latents.",
252
+ )
253
+
254
+ parser.add_argument("--offset_noise_weight", type=float, default=0.1, help="The weight of offset noise applied during training.")
255
+ parser.add_argument(
256
+ "--num_class_images",
257
+ type=int,
258
+ default=100,
259
+ help=(
260
+ "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
261
+ " sampled with class_prompt."
262
+ ),
263
+ )
264
+ parser.add_argument(
265
+ "--output_dir",
266
+ type=str,
267
+ default="text-inversion-model",
268
+ help="The output directory where the model predictions and checkpoints will be written.",
269
+ )
270
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
271
+ parser.add_argument(
272
+ "--resolution",
273
+ type=int,
274
+ default=512,
275
+ help=(
276
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
277
+ " resolution"
278
+ ),
279
+ )
280
+ parser.add_argument(
281
+ "--center_crop", default=False, action="store_true", help="Whether to center crop images before resizing to resolution"
282
+ )
283
+ parser.add_argument("--train_text_encoder", default=False, action="store_true", help="Whether to train the text encoder")
284
+ parser.add_argument(
285
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
286
+ )
287
+ parser.add_argument(
288
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
289
+ )
290
+ parser.add_argument("--num_train_epochs", type=int, default=1)
291
+ parser.add_argument(
292
+ "--max_train_steps",
293
+ type=int,
294
+ default=None,
295
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
296
+ )
297
+ parser.add_argument(
298
+ "--gradient_accumulation_steps",
299
+ type=int,
300
+ default=1,
301
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
302
+ )
303
+ parser.add_argument(
304
+ "--gradient_checkpointing",
305
+ default=False,
306
+ action="store_true",
307
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
308
+ )
309
+ parser.add_argument(
310
+ "--learning_rate",
311
+ type=float,
312
+ default=5e-6,
313
+ help="Initial learning rate (after the potential warmup period) to use.",
314
+ )
315
+ parser.add_argument(
316
+ "--scale_lr",
317
+ action="store_true",
318
+ default=False,
319
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
320
+ )
321
+ parser.add_argument(
322
+ "--lr_scheduler",
323
+ type=str,
324
+ default="constant",
325
+ help=(
326
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
327
+ ' "constant", "constant_with_warmup"]'
328
+ ),
329
+ )
330
+ parser.add_argument(
331
+ "--lr_warmup_steps", type=float, default=500, help="Number of steps for the warmup in the lr scheduler."
332
+ )
333
+ parser.add_argument(
334
+ "--use_8bit_adam", default=False, action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
335
+ )
336
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
337
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
338
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
339
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
340
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
341
+ parser.add_argument("--push_to_hub", default=False, action="store_true", help="Whether or not to push the model to the Hub.")
342
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
343
+ parser.add_argument(
344
+ "--hub_model_id",
345
+ type=str,
346
+ default=None,
347
+ help="The name of the repository to keep in sync with the local `output_dir`.",
348
+ )
349
+ parser.add_argument(
350
+ "--logging_dir",
351
+ type=str,
352
+ default="logs",
353
+ help=(
354
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
355
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
356
+ ),
357
+ )
358
+ parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
359
+ parser.add_argument("--sample_step_interval", type=int, default=100000000000000, help="Sample images every N steps.")
360
+ parser.add_argument(
361
+ "--mixed_precision",
362
+ type=str,
363
+ default="no",
364
+ choices=["no", "fp16", "bf16","tf32"],
365
+ help=(
366
+ "Whether to use mixed precision. Choose"
367
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
368
+ "and an Nvidia Ampere GPU."
369
+ ),
370
+ )
371
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
372
+ parser.add_argument(
373
+ "--concepts_list",
374
+ type=str,
375
+ default=None,
376
+ help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
377
+ )
378
+ parser.add_argument("--save_sample_controlled_seed", type=int, action='append', help="Set a seed for an extra sample image to be constantly saved.")
379
+ parser.add_argument("--detect_full_drive", default=True, action="store_true", help="Delete checkpoints when the drive is full.")
380
+ parser.add_argument("--send_telegram_updates", default=False, action="store_true", help="Send Telegram updates.")
381
+ parser.add_argument("--telegram_chat_id", type=str, default="0", help="Telegram chat ID.")
382
+ parser.add_argument("--telegram_token", type=str, default="0", help="Telegram token.")
383
+ parser.add_argument("--use_deepspeed_adam", default=False, action="store_true", help="Use experimental DeepSpeed Adam 8.")
384
+ parser.add_argument('--append_sample_controlled_seed_action', action='append')
385
+ parser.add_argument('--add_sample_prompt', type=str, action='append')
386
+ parser.add_argument('--use_image_names_as_captions', default=False, action="store_true")
387
+ parser.add_argument('--shuffle_captions', default=False, action="store_true")
388
+ parser.add_argument("--masked_training", default=False, required=False, action='store_true', help="Whether to mask parts of the image during training")
389
+ parser.add_argument("--normalize_masked_area_loss", default=False, required=False, action='store_true', help="Normalize the loss, to make it independent of the size of the masked area")
390
+ parser.add_argument("--unmasked_probability", type=float, default=1, required=False, help="Probability of training a step without a mask")
391
+ parser.add_argument("--max_denoising_strength", type=float, default=1, required=False, help="Max denoising steps to train on")
392
+ parser.add_argument('--add_mask_prompt', type=str, default=None, action="append", dest="mask_prompts", help="Prompt for automatic mask creation")
393
+ parser.add_argument('--with_gan', default=False, action="store_true", help="Use GAN (experimental)")
394
+ parser.add_argument("--gan_weight", type=float, default=0.2, required=False, help="Strength of effect GAN has on training")
395
+ parser.add_argument("--gan_warmup", type=float, default=0, required=False, help="Slowly increases GAN weight from zero over this many steps, useful when initializing a GAN discriminator from scratch")
396
+ parser.add_argument('--discriminator_config', default="configs/discriminator_large.json", help="Location of config file to use when initializing a new GAN discriminator")
397
+ parser.add_argument('--sample_from_ema', default=True, action="store_true", help="Generate sample images using the EMA model")
398
+ parser.add_argument('--run_name', type=str, default=None, help="Adds a custom identifier to the sample and checkpoint directories")
399
+ args = parser.parse_args()
400
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
401
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
402
+ args.local_rank = env_local_rank
403
+
404
+ return args
405
+
406
+ def main():
407
+ print(f" {bcolors.OKBLUE}Booting Up StableTuner{bcolors.ENDC}")
408
+ print(f" {bcolors.OKBLUE}Please wait a moment as we load up some stuff...{bcolors.ENDC}")
409
+ #torch.cuda.set_per_process_memory_fraction(0.5)
410
+ args = parse_args()
411
+ #temp arg
412
+ args.batch_tokens = None
413
+ if args.disable_cudnn_benchmark:
414
+ torch.backends.cudnn.benchmark = False
415
+ else:
416
+ torch.backends.cudnn.benchmark = True
417
+ if args.send_telegram_updates:
418
+ send_telegram_message(f"Booting up StableTuner!\n", args.telegram_chat_id, args.telegram_token)
419
+ logging_dir = Path(args.output_dir, "logs", args.logging_dir)
420
+ if args.run_name:
421
+ main_sample_dir = os.path.join(args.output_dir, f"samples_{args.run_name}")
422
+ else:
423
+ main_sample_dir = os.path.join(args.output_dir, "samples")
424
+ if os.path.exists(main_sample_dir):
425
+ shutil.rmtree(main_sample_dir)
426
+ os.makedirs(main_sample_dir)
427
+ #create logging directory
428
+ if not logging_dir.exists():
429
+ logging_dir.mkdir(parents=True)
430
+ #create output directory
431
+ if not Path(args.output_dir).exists():
432
+ Path(args.output_dir).mkdir(parents=True)
433
+
434
+
435
+ accelerator = Accelerator(
436
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
437
+ mixed_precision=args.mixed_precision if args.mixed_precision != 'tf32' else 'no',
438
+ log_with="tensorboard",
439
+ logging_dir=logging_dir,
440
+ )
441
+
442
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
443
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
444
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
445
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
446
+ raise ValueError(
447
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
448
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
449
+ )
450
+
451
+ if args.seed is not None:
452
+ set_seed(args.seed)
453
+
454
+ if args.concepts_list is None:
455
+ args.concepts_list = [
456
+ {
457
+ "instance_prompt": args.instance_prompt,
458
+ "class_prompt": args.class_prompt,
459
+ "instance_data_dir": args.instance_data_dir,
460
+ "class_data_dir": args.class_data_dir
461
+ }
462
+ ]
463
+ else:
464
+ with open(args.concepts_list, "r") as f:
465
+ args.concepts_list = json.load(f)
466
+
467
+ if args.with_prior_preservation or args.add_class_images_to_dataset:
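+ # Prior preservation / class-image top-up: for every concept, generate class images with the
+ # base pipeline and class_prompt until class_data_dir holds at least num_class_images files.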
468
+ pipeline = None
469
+ for concept in args.concepts_list:
470
+ class_images_dir = Path(concept["class_data_dir"])
471
+ class_images_dir.mkdir(parents=True, exist_ok=True)
472
+ cur_class_images = len(list(class_images_dir.iterdir()))
473
+
474
+ if cur_class_images < args.num_class_images:
475
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
476
+ if pipeline is None:
477
+
478
+ pipeline = DiffusionPipeline.from_pretrained(
479
+ args.pretrained_model_name_or_path,
480
+ safety_checker=None,
481
+ vae=AutoencoderKL.from_pretrained(args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,subfolder=None if args.pretrained_vae_name_or_path else "vae" ,safe_serialization=True),
482
+ torch_dtype=torch_dtype,
483
+ requires_safety_checker=False,
484
+ )
485
+ pipeline.set_progress_bar_config(disable=True)
486
+ pipeline.to(accelerator.device)
487
+
488
+ #if args.use_bucketing == False:
489
+ num_new_images = args.num_class_images - cur_class_images
490
+ logger.info(f"Number of class images to sample: {num_new_images}.")
491
+
492
+ sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
493
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
494
+ sample_dataloader = accelerator.prepare(sample_dataloader)
495
+ #else:
496
+ #create class images that match up to the concept target buckets
497
+ # instance_images_dir = Path(concept["instance_data_dir"])
498
+ # cur_instance_images = len(list(instance_images_dir.iterdir()))
499
+ #target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
500
+ # num_new_images = cur_instance_images - cur_class_images
501
+
502
+
503
+
504
+ with torch.autocast("cuda"):
505
+ for example in tqdm(
506
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
507
+ ):
508
+ with torch.autocast("cuda"):
509
+ images = pipeline(example["prompt"],height=args.resolution,width=args.resolution).images
510
+ for i, image in enumerate(images):
511
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
512
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
513
+ image.save(image_filename)
514
+
515
+ del pipeline
516
+ if torch.cuda.is_available():
517
+ torch.cuda.empty_cache()
518
+ torch.cuda.ipc_collect()
519
+ # Load the tokenizer
520
+ if args.tokenizer_name:
521
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name )
522
+ elif args.pretrained_model_name_or_path:
523
+ #print(os.getcwd())
524
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer" )
525
+
526
+ # Load models and create wrapper for stable diffusion
527
+ #text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder" )
528
+ text_encoder = CLIPTextModel.from_pretrained(
529
+ args.pretrained_model_name_or_path,
530
+ subfolder="text_encoder",
531
+ revision=args.revision,
532
+ )
533
+ vae = AutoencoderKL.from_pretrained(
534
+ args.pretrained_model_name_or_path,
535
+ subfolder="vae",
536
+ revision=args.revision,
537
+ )
538
+ unet = UNet2DConditionModel.from_pretrained(
539
+ args.pretrained_model_name_or_path,
540
+ subfolder="unet",
541
+ revision=args.revision,
542
+ torch_dtype=torch.float32
543
+ )
544
+
545
+ if args.with_gan:
546
+ if os.path.isdir(os.path.join(args.pretrained_model_name_or_path, "discriminator")):
547
+ discriminator = Discriminator2D.from_pretrained(
548
+ args.pretrained_model_name_or_path,
549
+ subfolder="discriminator",
550
+ revision=args.revision,
551
+ )
552
+ else:
553
+ print(f" {bcolors.WARNING}Discriminator network (GAN) not found. Initializing a new network. It may take a very large number of steps to train.{bcolors.ENDC}")
554
+ if not args.gan_warmup:
555
+ print(f" {bcolors.WARNING}Consider using --gan_warmup to stabilize the model while the discriminator is being trained.{bcolors.ENDC}")
556
+ with open(args.discriminator_config, "r") as f:
557
+ discriminator_config = json.load(f)
558
+ discriminator = Discriminator2D.from_config(discriminator_config)
559
+
560
+
561
+ if is_xformers_available() and args.attention=='xformers':
562
+ try:
563
+ vae.enable_xformers_memory_efficient_attention()
564
+ unet.enable_xformers_memory_efficient_attention()
565
+ if args.with_gan:
566
+ discriminator.enable_xformers_memory_efficient_attention()
567
+ except Exception as e:
568
+ logger.warning(
569
+ "Could not enable memory efficient attention. Make sure xformers is installed"
570
+ f" correctly and a GPU is available: {e}"
571
+ )
572
+ elif args.attention=='flash_attention':
573
+ replace_unet_cross_attn_to_flash_attention()
574
+
575
+ if args.use_ema == True:
576
+ if os.path.isdir(os.path.join(args.pretrained_model_name_or_path, "unet_ema")):
577
+ ema_unet = UNet2DConditionModel.from_pretrained(
578
+ args.pretrained_model_name_or_path,
579
+ subfolder="unet_ema",
580
+ revision=args.revision,
581
+ torch_dtype=torch.float32
582
+ )
583
+ else:
584
+ ema_unet = copy.deepcopy(unet)
585
+ ema_unet.config["step"] = 0
586
+ for param in ema_unet.parameters():
587
+ param.requires_grad = False
588
+
589
+ if args.model_variant == "depth2img":
590
+ d2i = Depth2Img(unet,text_encoder,args.mixed_precision,args.pretrained_model_name_or_path,accelerator)
591
+ vae.requires_grad_(False)
592
+ vae.enable_slicing()
593
+ if not args.train_text_encoder:
594
+ text_encoder.requires_grad_(False)
595
+
596
+ if args.gradient_checkpointing:
597
+ unet.enable_gradient_checkpointing()
598
+ if args.train_text_encoder:
599
+ text_encoder.gradient_checkpointing_enable()
600
+ if args.with_gan:
601
+ discriminator.enable_gradient_checkpointing()
602
+
603
+ if args.scale_lr:
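+ # Linear LR scaling: multiply the base rate by the effective batch size
+ # (gradient accumulation steps * per-device batch size * number of processes).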
604
+ args.learning_rate = (
605
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
606
+ )
607
+
608
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
609
+ if args.use_8bit_adam and args.use_deepspeed_adam==False and args.use_lion==False:
610
+ try:
611
+ import bitsandbytes as bnb
612
+ except ImportError:
613
+ raise ImportError(
614
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
615
+ )
616
+ optimizer_class = bnb.optim.AdamW8bit
617
+ print("Using 8-bit Adam")
618
+ elif args.use_8bit_adam and args.use_deepspeed_adam==True:
619
+ try:
620
+ from deepspeed.ops.adam import DeepSpeedCPUAdam
621
+ except ImportError:
622
+ raise ImportError(
623
+ "To use 8-bit DeepSpeed Adam, try updating your cuda and deepspeed integrations."
624
+ )
625
+ optimizer_class = DeepSpeedCPUAdam
626
+ elif args.use_lion == True:
627
+ print("Using LION optimizer")
628
+ optimizer_class = Lion
629
+ elif args.use_deepspeed_adam==False and args.use_lion==False and args.use_8bit_adam==False:
630
+ optimizer_class = torch.optim.AdamW
631
+ params_to_optimize = (
632
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
633
+ )
634
+ if args.use_lion == False:
635
+ optimizer = optimizer_class(
636
+ params_to_optimize,
637
+ lr=args.learning_rate,
638
+ betas=(args.adam_beta1, args.adam_beta2),
639
+ weight_decay=args.adam_weight_decay,
640
+ eps=args.adam_epsilon,
641
+ )
642
+ if args.with_gan:
643
+ optimizer_discriminator = optimizer_class(
644
+ discriminator.parameters(),
645
+ lr=args.learning_rate,
646
+ betas=(args.adam_beta1, args.adam_beta2),
647
+ weight_decay=args.adam_weight_decay,
648
+ eps=args.adam_epsilon,
649
+ )
650
+ else:
651
+ optimizer = optimizer_class(
652
+ params_to_optimize,
653
+ lr=args.learning_rate,
654
+ betas=(args.adam_beta1, args.adam_beta2),
655
+ weight_decay=args.adam_weight_decay,
656
+ #eps=args.adam_epsilon,
657
+ )
658
+ if args.with_gan:
659
+ optimizer_discriminator = optimizer_class(
660
+ discriminator.parameters(),
661
+ lr=args.learning_rate,
662
+ betas=(args.adam_beta1, args.adam_beta2),
663
+ weight_decay=args.adam_weight_decay,
664
+ #eps=args.adam_epsilon,
665
+ )
666
+ noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
667
+
668
+ if args.use_bucketing:
669
+ train_dataset = AutoBucketing(
670
+ concepts_list=args.concepts_list,
671
+ use_image_names_as_captions=args.use_image_names_as_captions,
672
+ shuffle_captions=args.shuffle_captions,
673
+ batch_size=args.train_batch_size,
674
+ tokenizer=tokenizer,
675
+ add_class_images_to_dataset=args.add_class_images_to_dataset,
676
+ balance_datasets=args.auto_balance_concept_datasets,
677
+ resolution=args.resolution,
678
+ with_prior_loss=False,#args.with_prior_preservation,
679
+ repeats=args.dataset_repeats,
680
+ use_text_files_as_captions=args.use_text_files_as_captions,
681
+ aspect_mode=args.aspect_mode,
682
+ action_preference=args.aspect_mode_action_preference,
683
+ seed=args.seed,
684
+ model_variant=args.model_variant,
685
+ extra_module=None if args.model_variant != "depth2img" else d2i,
686
+ mask_prompts=args.mask_prompts,
687
+ load_mask=args.masked_training,
688
+ )
689
+ else:
690
+ train_dataset = NormalDataset(
691
+ concepts_list=args.concepts_list,
692
+ tokenizer=tokenizer,
693
+ with_prior_preservation=args.with_prior_preservation,
694
+ size=args.resolution,
695
+ center_crop=args.center_crop,
696
+ num_class_images=args.num_class_images,
697
+ use_image_names_as_captions=args.use_image_names_as_captions,
698
+ shuffle_captions=args.shuffle_captions,
699
+ repeats=args.dataset_repeats,
700
+ use_text_files_as_captions=args.use_text_files_as_captions,
701
+ seed = args.seed,
702
+ model_variant=args.model_variant,
703
+ extra_module=None if args.model_variant != "depth2img" else d2i,
704
+ mask_prompts=args.mask_prompts,
705
+ load_mask=args.masked_training,
706
+ )
707
+ def collate_fn(examples):
708
+ #print(examples)
709
+ #print('test')
710
+ input_ids = [example["instance_prompt_ids"] for example in examples]
711
+ tokens = input_ids
712
+ pixel_values = [example["instance_images"] for example in examples]
713
+ mask = None
714
+ if "mask" in examples[0]:
715
+ mask = [example["mask"] for example in examples]
716
+ if args.model_variant == 'depth2img':
717
+ depth = [example["instance_depth_images"] for example in examples]
718
+
719
+ #print('test')
720
+ # Concat class and instance examples for prior preservation.
721
+ # We do this to avoid doing two forward passes.
722
+ if args.with_prior_preservation:
723
+ input_ids += [example["class_prompt_ids"] for example in examples]
724
+ pixel_values += [example["class_images"] for example in examples]
725
+ if "mask" in examples[0]:
726
+ mask += [example["class_mask"] for example in examples]
727
+ if args.model_variant == 'depth2img':
728
+ depth = [example["class_depth_images"] for example in examples]
729
+ mask_values = None
730
+ if mask is not None:
731
+ mask_values = torch.stack(mask)
732
+ mask_values = mask_values.to(memory_format=torch.contiguous_format).float()
733
+ if args.model_variant == 'depth2img':
734
+ depth_values = torch.stack(depth)
735
+ depth_values = depth_values.to(memory_format=torch.contiguous_format).float()
736
+ ### no need to do it now when it's loaded by the multiAspectsDataset
737
+ #if args.with_prior_preservation:
738
+ # input_ids += [example["class_prompt_ids"] for example in examples]
739
+ # pixel_values += [example["class_images"] for example in examples]
740
+
741
+ #print(pixel_values)
742
+ #unpack the pixel_values from tensor to list
743
+
744
+
745
+ pixel_values = torch.stack(pixel_values)
746
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
747
+ input_ids = tokenizer.pad(
748
+ {"input_ids": input_ids},
749
+ padding="max_length",
750
+ max_length=tokenizer.model_max_length,
751
+ return_tensors="pt",\
752
+ ).input_ids
753
+
754
+ extra_values = None
755
+ if args.model_variant == 'depth2img':
756
+ extra_values = depth_values
757
+
758
+ return {
759
+ "input_ids": input_ids,
760
+ "pixel_values": pixel_values,
761
+ "extra_values": extra_values,
762
+ "mask_values": mask_values,
763
+ "tokens": tokens
764
+ }
765
+
766
+ train_dataloader = torch.utils.data.DataLoader(
767
+ train_dataset, batch_size=args.train_batch_size, shuffle=False, collate_fn=collate_fn, pin_memory=True
768
+ )
769
+ #get the length of the dataset
770
+ train_dataset_length = len(train_dataset)
771
+ #code to check if latent cache needs to be resaved
772
+ #check if last_run.json file exists in logging_dir
773
+ if os.path.exists(logging_dir / "last_run.json"):
774
+ #if it exists, load it
775
+ with open(logging_dir / "last_run.json", "r") as f:
776
+ last_run = json.load(f)
777
+ last_run_batch_size = last_run["batch_size"]
778
+ last_run_dataset_length = last_run["dataset_length"]
779
+ if last_run_batch_size != args.train_batch_size:
780
+ print(f" {bcolors.WARNING}The batch_size has changed since the last run. Regenerating Latent Cache.{bcolors.ENDC}")
781
+
782
+ args.regenerate_latent_cache = True
783
+ #save the new batch_size and dataset_length to last_run.json
784
+ if last_run_dataset_length != train_dataset_length:
785
+ print(f" {bcolors.WARNING}The dataset length has changed since the last run. Regenerating Latent Cache.{bcolors.ENDC}")
786
+
787
+ args.regenerate_latent_cache = True
788
+ #save the new batch_size and dataset_length to last_run.json
789
+ with open(logging_dir / "last_run.json", "w") as f:
790
+ json.dump({"batch_size": args.train_batch_size, "dataset_length": train_dataset_length}, f)
791
+
792
+ else:
793
+ #if it doesn't exist, create it
794
+ last_run = {"batch_size": args.train_batch_size, "dataset_length": train_dataset_length}
795
+ #create the file
796
+ with open(logging_dir / "last_run.json", "w") as f:
797
+ json.dump(last_run, f)
798
+
799
+ weight_dtype = torch.float32
800
+ if accelerator.mixed_precision == "fp16":
801
+ print("Using fp16")
802
+ weight_dtype = torch.float16
803
+ elif accelerator.mixed_precision == "bf16":
804
+ print("Using bf16")
805
+ weight_dtype = torch.bfloat16
806
+ elif args.mixed_precision == "tf32":
807
+ torch.backends.cuda.matmul.allow_tf32 = True
808
+ #torch.set_float32_matmul_precision("medium")
809
+
810
+ # Move text_encode and vae to gpu.
811
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
812
+ # as these models are only used for inference, keeping weights in full precision is not required.
813
+ vae.to(accelerator.device, dtype=weight_dtype)
814
+ if args.use_ema == True:
815
+ ema_unet.to(accelerator.device)
816
+ if not args.train_text_encoder:
817
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
818
+
819
+ if args.use_bucketing:
820
+ wh = set([tuple(x.target_wh) for x in train_dataset.image_train_items])
821
+ else:
822
+ wh = set([tuple([args.resolution, args.resolution]) for x in train_dataset.image_paths])
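+ # For every bucket resolution, encode an all-zero image and keep its latent mean scaled by
+ # 0.18215 (the Stable Diffusion VAE scaling factor); used later as the fully-masked latent per aspect.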
823
+ full_mask_by_aspect = {shape: vae.encode(torch.zeros(1, 3, shape[1], shape[0]).to(accelerator.device, dtype=weight_dtype)).latent_dist.mean * 0.18215 for shape in wh}
824
+
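+ # Latents (and text embeddings when the text encoder is frozen) are pre-computed once and stored
+ # as latents_cache_*.pt shards, so the VAE and frozen text encoder can be freed from VRAM below.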
825
+ cached_dataset = CachedLatentsDataset(batch_size=args.train_batch_size,
826
+ text_encoder=text_encoder,
827
+ tokenizer=tokenizer,
828
+ dtype=weight_dtype,
829
+ model_variant=args.model_variant,
830
+ shuffle_per_epoch=False,
831
+ args = args,)
832
+
833
+ gen_cache = False
834
+ data_len = len(train_dataloader)
835
+ latent_cache_dir = Path(args.output_dir, "logs", "latent_cache")
836
+ #check if latents_cache.pt exists in the output_dir
837
+ if not os.path.exists(latent_cache_dir):
838
+ os.makedirs(latent_cache_dir)
839
+ for i in range(0,data_len-1):
840
+ if not os.path.exists(os.path.join(latent_cache_dir, f"latents_cache_{i}.pt")):
841
+ gen_cache = True
842
+ break
843
+ if args.regenerate_latent_cache == True:
844
+ files = os.listdir(latent_cache_dir)
845
+ gen_cache = True
846
+ for file in files:
847
+ os.remove(os.path.join(latent_cache_dir,file))
848
+ if gen_cache == False :
849
+ print(f" {bcolors.OKGREEN}Loading Latent Cache from {latent_cache_dir}{bcolors.ENDC}")
850
+ del vae
851
+ if not args.train_text_encoder:
852
+ del text_encoder
853
+ if torch.cuda.is_available():
854
+ torch.cuda.empty_cache()
855
+ torch.cuda.ipc_collect()
856
+ #load all the cached latents into a single dataset
857
+ for i in range(0,data_len-1):
858
+ cached_dataset.add_pt_cache(os.path.join(latent_cache_dir,f"latents_cache_{i}.pt"))
859
+ if gen_cache == True:
860
+ #delete all the cached latents if they exist to avoid problems
861
+ print(f" {bcolors.WARNING}Generating latents cache...{bcolors.ENDC}")
862
+ train_dataset = LatentsDataset([], [], [], [], [], [])
863
+ counter = 0
864
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
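+ # For each batch: encode pixels to VAE latent distributions, resize masks to the latent resolution,
+ # additionally encode the masked image for the inpainting variant, and cache token ids or text embeddings.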
865
+ with torch.no_grad():
866
+ for batch in tqdm(train_dataloader, desc="Caching latents", bar_format='%s{l_bar}%s%s{bar}%s%s{r_bar}%s'%(bcolors.OKBLUE,bcolors.ENDC, bcolors.OKBLUE, bcolors.ENDC,bcolors.OKBLUE,bcolors.ENDC,)):
867
+ cached_extra = None
868
+ cached_mask = None
869
+ batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
870
+ batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
871
+ cached_latent = vae.encode(batch["pixel_values"]).latent_dist
872
+ if batch["mask_values"] is not None:
873
+ cached_mask = functional.resize(batch["mask_values"], size=cached_latent.mean.shape[2:])
874
+ if batch["mask_values"] is not None and args.model_variant == "inpainting":
875
+ batch["mask_values"] = batch["mask_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
876
+ cached_extra = vae.encode(batch["pixel_values"] * (1 - batch["mask_values"])).latent_dist
877
+ if args.model_variant == "depth2img":
878
+ batch["extra_values"] = batch["extra_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
879
+ cached_extra = functional.resize(batch["extra_values"], size=cached_latent.mean.shape[2:])
880
+ if args.train_text_encoder:
881
+ cached_text_enc = batch["input_ids"]
882
+ else:
883
+ cached_text_enc = text_encoder(batch["input_ids"])[0]
884
+ train_dataset.add_latent(cached_latent, cached_text_enc, cached_mask, cached_extra, batch["tokens"])
885
+ del batch
886
+ del cached_latent
887
+ del cached_text_enc
888
+ del cached_mask
889
+ del cached_extra
890
+ torch.save(train_dataset, os.path.join(latent_cache_dir,f"latents_cache_{counter}.pt"))
891
+ cached_dataset.add_pt_cache(os.path.join(latent_cache_dir,f"latents_cache_{counter}.pt"))
892
+ counter += 1
893
+ train_dataset = LatentsDataset([], [], [], [], [], [])
894
+ #if counter % 300 == 0:
895
+ #train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=False)
896
+ # gc.collect()
897
+ # torch.cuda.empty_cache()
898
+ # accelerator.free_memory()
899
+
900
+ #clear vram after caching latents
901
+ del vae
902
+ if not args.train_text_encoder:
903
+ del text_encoder
904
+ if torch.cuda.is_available():
905
+ torch.cuda.empty_cache()
906
+ torch.cuda.ipc_collect()
907
+ #load all the cached latents into a single dataset
908
+ train_dataloader = torch.utils.data.DataLoader(cached_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=False)
909
+ print(f" {bcolors.OKGREEN}Latents are ready.{bcolors.ENDC}")
910
+ # Scheduler and math around the number of training steps.
911
+ overrode_max_train_steps = False
912
+ num_update_steps_per_epoch = len(train_dataloader)
913
+ if args.max_train_steps is None:
914
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
915
+ overrode_max_train_steps = True
916
+
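+ # A --lr_warmup_steps value below 1 is treated as a fraction of the total update steps.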
917
+ if args.lr_warmup_steps < 1:
918
+ args.lr_warmup_steps = math.floor(args.lr_warmup_steps * args.max_train_steps / args.gradient_accumulation_steps)
919
+
920
+ lr_scheduler = get_scheduler(
921
+ args.lr_scheduler,
922
+ optimizer=optimizer,
923
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
924
+ num_training_steps=args.max_train_steps,
925
+ )
926
+
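+ # accelerator.prepare() must wrap every trainable module together with the optimizer, dataloader and
+ # LR scheduler; the branches below cover each combination of --train_text_encoder and --use_ema.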
927
+ if args.train_text_encoder and not args.use_ema:
928
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
929
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
930
+ )
931
+ elif args.train_text_encoder and args.use_ema:
932
+ unet, text_encoder, ema_unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
933
+ unet, text_encoder, ema_unet, optimizer, train_dataloader, lr_scheduler
934
+ )
935
+ elif not args.train_text_encoder and args.use_ema:
936
+ unet, ema_unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
937
+ unet, ema_unet, optimizer, train_dataloader, lr_scheduler
938
+ )
939
+ elif not args.train_text_encoder and not args.use_ema:
940
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
941
+ unet, optimizer, train_dataloader, lr_scheduler
942
+ )
943
+ if args.with_gan:
944
+ lr_scheduler_discriminator = get_scheduler(
945
+ args.lr_scheduler,
946
+ optimizer=optimizer_discriminator,
947
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
948
+ num_training_steps=args.max_train_steps,
949
+ )
950
+ discriminator, optimizer_discriminator, lr_scheduler_discriminator = accelerator.prepare(discriminator, optimizer_discriminator, lr_scheduler_discriminator)
951
+
952
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
953
+ num_update_steps_per_epoch = len(train_dataloader)
954
+ if overrode_max_train_steps:
955
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
956
+ #print(args.max_train_steps, num_update_steps_per_epoch)
957
+ # Afterwards we recalculate our number of training epochs
958
+ #print(args.max_train_steps, num_update_steps_per_epoch)
959
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
960
+
961
+ # We need to initialize the trackers we use, and also store our configuration.
962
+ # The trackers initialize automatically on the main process.
963
+ if accelerator.is_main_process:
964
+ accelerator.init_trackers("dreambooth")
965
+ # Train!
966
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
967
+
968
+ logger.info("***** Running training *****")
969
+ logger.info(f" Num examples = {len(train_dataset)}")
970
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
971
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
972
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
973
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
974
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
975
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
976
+ def mid_train_playground(step):
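+ # Pauses training and opens a temporary Gradio demo for test generations with the current
+ # (EMA, if enabled) weights; training resumes once F12 is pressed and the demo is closed.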
977
+
978
+ tqdm.write(f"{bcolors.WARNING} Booting up GUI{bcolors.ENDC}")
979
+ epoch = step // num_update_steps_per_epoch
980
+ if args.train_text_encoder and args.stop_text_encoder_training == True:
981
+ text_enc_model = accelerator.unwrap_model(text_encoder,True)
982
+ elif args.train_text_encoder and args.stop_text_encoder_training > epoch:
983
+ text_enc_model = accelerator.unwrap_model(text_encoder,True)
984
+ elif args.train_text_encoder == False:
985
+ text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder" )
986
+ elif args.train_text_encoder and args.stop_text_encoder_training <= epoch:
987
+ if 'frozen_directory' in locals():
988
+ text_enc_model = CLIPTextModel.from_pretrained(frozen_directory, subfolder="text_encoder")
989
+ else:
990
+ text_enc_model = accelerator.unwrap_model(text_encoder,True)
991
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
992
+ unwrapped_unet = accelerator.unwrap_model(ema_unet if args.use_ema else unet,True)
993
+
994
+ pipeline = DiffusionPipeline.from_pretrained(
995
+ args.pretrained_model_name_or_path,
996
+ unet=unwrapped_unet,
997
+ text_encoder=text_enc_model,
998
+ vae=AutoencoderKL.from_pretrained(args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,subfolder=None if args.pretrained_vae_name_or_path else "vae", safe_serialization=True),
999
+ safety_checker=None,
1000
+ torch_dtype=weight_dtype,
1001
+ local_files_only=False,
1002
+ requires_safety_checker=False,
1003
+ )
1004
+ pipeline.scheduler = scheduler
1005
+ if is_xformers_available() and args.attention=='xformers':
1006
+ try:
1007
+ vae.enable_xformers_memory_efficient_attention()
1008
+ unet.enable_xformers_memory_efficient_attention()
1009
+ except Exception as e:
1010
+ logger.warning(
1011
+ "Could not enable memory efficient attention. Make sure xformers is installed"
1012
+ f" correctly and a GPU is available: {e}"
1013
+ )
1014
+ elif args.attention=='flash_attention':
1015
+ replace_unet_cross_attn_to_flash_attention()
1016
+ pipeline = pipeline.to(accelerator.device)
1017
+ def inference(prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50,seed=-1,guidance_scale=7.5):
1018
+ with torch.autocast("cuda"), torch.inference_mode():
1019
+ if seed != -1:
1020
+ # fix: g_cuda was referenced before assignment; always create and seed the generator
1021
+ g_cuda = torch.Generator(device='cuda')
1022
+ g_cuda.manual_seed(int(seed))
1024
+ else:
1025
+ seed = random.randint(0, 100000)
1026
+ g_cuda = torch.Generator(device='cuda')
1027
+ g_cuda.manual_seed(seed)
1028
+ return pipeline(
1029
+ prompt, height=int(height), width=int(width),
1030
+ negative_prompt=negative_prompt,
1031
+ num_images_per_prompt=int(num_samples),
1032
+ num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,
1033
+ generator=g_cuda).images, seed
1034
+
1035
+ with gr.Blocks() as demo:
1036
+ with gr.Row():
1037
+ with gr.Column():
1038
+ prompt = gr.Textbox(label="Prompt", value="photo of zwx dog in a bucket")
1039
+ negative_prompt = gr.Textbox(label="Negative Prompt", value="")
1040
+ run = gr.Button(value="Generate")
1041
+ with gr.Row():
1042
+ num_samples = gr.Number(label="Number of Samples", value=4)
1043
+ guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
1044
+ with gr.Row():
1045
+ height = gr.Number(label="Height", value=512)
1046
+ width = gr.Number(label="Width", value=512)
1047
+ with gr.Row():
1048
+ num_inference_steps = gr.Slider(label="Steps", value=25)
1049
+ seed = gr.Number(label="Seed", value=-1)
1050
+ with gr.Column():
1051
+ gallery = gr.Gallery()
1052
+ seedDisplay = gr.Number(label="Used Seed:", value=0)
1053
+
1054
+ run.click(inference, inputs=[prompt, negative_prompt, num_samples, height, width, num_inference_steps,seed, guidance_scale], outputs=[gallery,seedDisplay])
1055
+
1056
+ demo.launch(share=True,prevent_thread_lock=True)
1057
+ tqdm.write(f"{bcolors.WARNING}Gradio Session is active, Press 'F12' to resume training{bcolors.ENDC}")
1058
+ keyboard.wait('f12')
1059
+ demo.close()
1060
+ del demo
1061
+ del text_enc_model
1062
+ del unwrapped_unet
1063
+ del pipeline
1064
+ return
1065
+
1066
+ def save_and_sample_weights(step,context='checkpoint',save_model=True):
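+ # Saves a full diffusers pipeline checkpoint (plus EMA/discriminator weights and args.json) and
+ # renders sample images for the configured prompts, optionally sending Telegram updates.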
1067
+ try:
1068
+ #check how many folders are in the output dir
1069
+ #if there are more than 5, delete the oldest one
1070
+ #save the model
1071
+ #save the optimizer
1072
+ #save the lr_scheduler
1073
+ #save the args
1074
+ height = args.sample_height
1075
+ width = args.sample_width
1076
+ batch_prompts = []
1077
+ if args.sample_from_batch > 0:
1078
+ num_samples = args.sample_from_batch if args.sample_from_batch < args.train_batch_size else args.train_batch_size
1079
+ batch_prompts = []
1080
+ tokens = args.batch_tokens
1081
+ if tokens != None:
1082
+ allPrompts = list(set([tokenizer.decode(p).replace('<|endoftext|>','').replace('<|startoftext|>', '') for p in tokens]))
1083
+ if len(allPrompts) < num_samples:
1084
+ num_samples = len(allPrompts)
1085
+ batch_prompts = random.sample(allPrompts, num_samples)
1086
+
1087
+
1088
+ if args.sample_aspect_ratios:
1089
+ #choose random aspect ratio from ASPECTS
1090
+ aspect_ratio = random.choice(ASPECTS)
1091
+ height = aspect_ratio[0]
1092
+ width = aspect_ratio[1]
1093
+ if os.path.exists(args.output_dir):
1094
+ if args.detect_full_drive==True:
1095
+ folders = os.listdir(args.output_dir)
1096
+ #check how much space is left on the drive
1097
+ total, used, free = shutil.disk_usage("/")
1098
+ if (free // (2**30)) < 4:
1099
+ #folders.remove("0")
1100
+ #get the folder with the lowest number
1101
+ #oldest_folder = min(folder for folder in folders if folder.isdigit())
1102
+ tqdm.write(f"{bcolors.FAIL}Drive is almost full, Please make some space to continue training.{bcolors.ENDC}")
1103
+ if args.send_telegram_updates:
1104
+ try:
1105
+ send_telegram_message(f"Drive is almost full, Please make some space to continue training.", args.telegram_chat_id, args.telegram_token)
1106
+ except:
1107
+ pass
1108
+ #count time
1109
+ import time
1110
+ start_time = time.time()
1111
+ import platform
1112
+ while input("Press Enter to continue... if you're on linux we'll wait 5 minutes for you to make space and continue"):
1113
+ #check if five minutes have passed
1114
+ #check if os is linux
1115
+ if 'Linux' in platform.platform():
1116
+ if time.time() - start_time > 300:
1117
+ break
1118
+
1119
+
1120
+ #oldest_folder_path = os.path.join(args.output_dir, oldest_folder)
1121
+ #shutil.rmtree(oldest_folder_path)
1122
+ # Create the pipeline using using the trained modules and save it.
1123
+ if accelerator.is_main_process:
1124
+ if 'step' in context:
1125
+ #what is the current epoch
1126
+ epoch = step // num_update_steps_per_epoch
1127
+ else:
1128
+ epoch = step
1129
+ if args.train_text_encoder and args.stop_text_encoder_training == True:
1130
+ text_enc_model = accelerator.unwrap_model(text_encoder,True)
1131
+ elif args.train_text_encoder and args.stop_text_encoder_training > epoch:
1132
+ text_enc_model = accelerator.unwrap_model(text_encoder,True)
1133
+ elif args.train_text_encoder == False:
1134
+ text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder" )
1135
+ elif args.train_text_encoder and args.stop_text_encoder_training <= epoch:
1136
+ if 'frozen_directory' in locals():
1137
+ text_enc_model = CLIPTextModel.from_pretrained(frozen_directory, subfolder="text_encoder")
1138
+ else:
1139
+ text_enc_model = accelerator.unwrap_model(text_encoder,True)
1140
+
1141
+ #scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
1142
+ #scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", prediction_type="v_prediction")
1143
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
1144
+ unwrapped_unet = accelerator.unwrap_model(unet,True)
1145
+
1146
+ pipeline = DiffusionPipeline.from_pretrained(
1147
+ args.pretrained_model_name_or_path,
1148
+ unet=unwrapped_unet,
1149
+ text_encoder=text_enc_model,
1150
+ vae=AutoencoderKL.from_pretrained(args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,subfolder=None if args.pretrained_vae_name_or_path else "vae",),
1151
+ safety_checker=None,
1152
+ torch_dtype=weight_dtype,
1153
+ local_files_only=False,
1154
+ requires_safety_checker=False,
1155
+ )
1156
+ pipeline.scheduler = scheduler
1157
+ if is_xformers_available() and args.attention=='xformers':
1158
+ try:
1159
+ unet.enable_xformers_memory_efficient_attention()
1160
+ except Exception as e:
1161
+ logger.warning(
1162
+ "Could not enable memory efficient attention. Make sure xformers is installed"
1163
+ f" correctly and a GPU is available: {e}"
1164
+ )
1165
+ elif args.attention=='flash_attention':
1166
+ replace_unet_cross_attn_to_flash_attention()
1167
+ if args.run_name:
1168
+ save_dir = os.path.join(args.output_dir, f"{context}_{step}_{args.run_name}")
1169
+ else:
1170
+ save_dir = os.path.join(args.output_dir, f"{context}_{step}")
1171
+ if args.flatten_sample_folder:
1172
+ sample_dir = main_sample_dir
1173
+ else:
1174
+ sample_dir = os.path.join(main_sample_dir, f"{context}_{step}")
1175
+ #if sample dir path does not exist, create it
1176
+
1177
+ if args.stop_text_encoder_training == True:
1178
+ save_dir = frozen_directory
1179
+ if save_model:
1180
+ pipeline.save_pretrained(save_dir,safe_serialization=True)
1181
+ if args.with_gan:
1182
+ discriminator.save_pretrained(os.path.join(save_dir, "discriminator"), safe_serialization=True)
1183
+ if args.use_ema:
1184
+ ema_unet.save_pretrained(os.path.join(save_dir, "unet_ema"), safe_serialization=True)
1185
+ with open(os.path.join(save_dir, "args.json"), "w") as f:
1186
+ json.dump(args.__dict__, f, indent=2)
1187
+ if args.stop_text_encoder_training == True:
1188
+ #delete every folder in frozen_directory but the text encoder
1189
+ for folder in os.listdir(save_dir):
1190
+ if folder != "text_encoder" and os.path.isdir(os.path.join(save_dir, folder)):
1191
+ shutil.rmtree(os.path.join(save_dir, folder))
1192
+ imgs = []
1193
+ if args.use_ema and args.sample_from_ema:
1194
+ pipeline.unet = ema_unet
1195
+
1196
+ for param in unet.parameters():
1197
+ param.requires_grad = False
1198
+ if torch.cuda.is_available():
1199
+ torch.cuda.empty_cache()
1200
+ torch.cuda.ipc_collect()
1201
+
1202
+ if (args.add_sample_prompt is not None or batch_prompts != []) and args.stop_text_encoder_training != True:
1203
+ prompts = []
1204
+ if args.add_sample_prompt is not None:
1205
+ for prompt in args.add_sample_prompt:
1206
+ prompts.append(prompt)
1207
+ if batch_prompts != []:
1208
+ for prompt in batch_prompts:
1209
+ prompts.append(prompt)
1210
+
1211
+ pipeline = pipeline.to(accelerator.device)
1212
+ pipeline.set_progress_bar_config(disable=True)
1213
+ #sample_dir = os.path.join(save_dir, "samples")
1214
+ #if sample_dir exists, delete it
1215
+ if os.path.exists(sample_dir):
1216
+ if not args.flatten_sample_folder:
1217
+ shutil.rmtree(sample_dir)
1218
+ os.makedirs(sample_dir, exist_ok=True)
1219
+ with torch.autocast("cuda"), torch.inference_mode():
1220
+ if args.send_telegram_updates:
1221
+ try:
1222
+ send_telegram_message(f"Generating samples for <b>{step}</b> {context}", args.telegram_chat_id, args.telegram_token)
1223
+ except:
1224
+ pass
1225
+ n_sample = args.n_save_sample
1226
+ if args.save_sample_controlled_seed:
1227
+ n_sample += len(args.save_sample_controlled_seed)
1228
+ progress_bar_sample = tqdm(total=len(prompts)*n_sample,desc="Generating samples")
1229
+ for samplePrompt in prompts:
1230
+ sampleIndex = prompts.index(samplePrompt)
1231
+ #convert sampleIndex to number in words
1232
+ # Data to be written
1233
+ sampleProperties = {
1234
+ "samplePrompt" : samplePrompt
1235
+ }
1236
+
1237
+ # Serializing json
1238
+ json_object = json.dumps(sampleProperties, indent=4)
1239
+
1240
+ if args.flatten_sample_folder:
1241
+ sampleName = f"{context}_{step}_prompt_{sampleIndex+1}"
1242
+ else:
1243
+ sampleName = f"prompt_{sampleIndex+1}"
1244
+
1245
+ if not args.flatten_sample_folder:
1246
+ os.makedirs(os.path.join(sample_dir,sampleName), exist_ok=True)
1247
+
1248
+ if args.model_variant == 'inpainting':
1249
+ conditioning_image = torch.zeros(1, 3, height, width)
1250
+ mask = torch.ones(1, 1, height, width)
1251
+ if args.model_variant == 'depth2img':
1252
+ #pil new white image
1253
+ test_image = Image.new('RGB', (width, height), (255, 255, 255))
1254
+ depth_image = Image.new('RGB', (width, height), (255, 255, 255))
1255
+ depth = np.array(depth_image.convert("L"))
1256
+ depth = depth.astype(np.float32) / 255.0
1257
+ depth = depth[None, None]
1258
+ depth = torch.from_numpy(depth)
1259
+ for i in range(n_sample):
1260
+ #check if the sample is controlled by a seed
1261
+ if i < args.n_save_sample:
1262
+ if args.model_variant == 'inpainting':
1263
+ images = pipeline(samplePrompt, conditioning_image, mask, height=height,width=width, guidance_scale=args.save_guidance_scale, num_inference_steps=args.save_infer_steps).images
1264
+ if args.model_variant == 'depth2img':
1265
+ images = pipeline(samplePrompt,image=test_image, guidance_scale=args.save_guidance_scale, num_inference_steps=args.save_infer_steps,strength=1.0).images
1266
+ elif args.model_variant == 'base':
1267
+ images = pipeline(samplePrompt,height=height,width=width, guidance_scale=args.save_guidance_scale, num_inference_steps=args.save_infer_steps).images
1268
+
1269
+ if not args.flatten_sample_folder:
1270
+ images[0].save(os.path.join(sample_dir,sampleName, f"{sampleName}_{i}.png"))
1271
+ else:
1272
+ images[0].save(os.path.join(sample_dir, f"{sampleName}_{i}.png"))
1273
+
1274
+ else:
1275
+ seed = args.save_sample_controlled_seed[i - args.n_save_sample]
1276
+ generator = torch.Generator("cuda").manual_seed(seed)
1277
+ if args.model_variant == 'inpainting':
1278
+ images = pipeline(samplePrompt,conditioning_image, mask,height=height,width=width, guidance_scale=args.save_guidance_scale, num_inference_steps=args.save_infer_steps, generator=generator).images
1279
+ if args.model_variant == 'depth2img':
1280
+ images = pipeline(samplePrompt,image=test_image, guidance_scale=args.save_guidance_scale, num_inference_steps=args.save_infer_steps,generator=generator,strength=1.0).images
1281
+ elif args.model_variant == 'base':
1282
+ images = pipeline(samplePrompt,height=height,width=width, guidance_scale=args.save_guidance_scale, num_inference_steps=args.save_infer_steps, generator=generator).images
1283
+
1284
+ if not args.flatten_sample_folder:
1285
+ images[0].save(os.path.join(sample_dir,sampleName, f"{sampleName}_controlled_seed_{str(seed)}.png"))
1286
+ else:
1287
+ images[0].save(os.path.join(sample_dir, f"{sampleName}_controlled_seed_{str(seed)}.png"))
1288
+ progress_bar_sample.update(1)
1289
+
1290
+ if args.send_telegram_updates:
1291
+ imgs = []
1292
+ #get all the images from the sample folder
1293
+ if not args.flatten_sample_folder:
1294
+ img_dir = os.path.join(sample_dir, sampleName)
1295
+ else:
1296
+ img_dir = sample_dir
1297
+
1298
+ for file in os.listdir(img_dir):
1299
+ if file.endswith(".png"):
1300
+ #open the image with pil
1301
+ img = Image.open(os.path.join(img_dir, file))
1302
+ imgs.append(img)
1303
+ try:
1304
+ send_media_group(args.telegram_chat_id,args.telegram_token,imgs, caption=f"Samples for the <b>{step}</b> {context} using the prompt:\n\n<b>{samplePrompt}</b>")
1305
+ except:
1306
+ pass
1307
+ del pipeline
1308
+ del unwrapped_unet
1309
+ for param in unet.parameters():
1310
+ param.requires_grad = True
1311
+ if torch.cuda.is_available():
1312
+ torch.cuda.empty_cache()
1313
+ torch.cuda.ipc_collect()
1314
+ if save_model == True:
1315
+ tqdm.write(f"{bcolors.OKGREEN}Weights saved to {save_dir}{bcolors.ENDC}")
1316
+ elif save_model == False and len(imgs) > 0:
1317
+ del imgs
1318
+ tqdm.write(f"{bcolors.OKGREEN}Samples saved to {sample_dir}{bcolors.ENDC}")
1319
+
1320
+ except Exception as e:
1321
+ tqdm.write(str(e))
1322
+ tqdm.write(f"{bcolors.FAIL} Error occurred during sampling, skipping.{bcolors.ENDC}")
1323
+ pass
1324
+
1325
+ @torch.no_grad()
1326
+ def update_ema(ema_model, model):
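+ # EMA with a warmup schedule: decay = min((step + 1) / (step + 10), 0.9999), so early updates track
+ # the live model closely and later updates average over a long horizon; frozen params are copied as-is.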
1327
+ ema_step = ema_model.config["step"]
1328
+ decay = min((ema_step + 1) / (ema_step + 10), 0.9999)
1329
+ ema_model.config["step"] += 1
1330
+ for (s_param, param) in zip(ema_model.parameters(), model.parameters()):
1331
+ if param.requires_grad:
1332
+ s_param.add_((1 - decay) * (param - s_param))
1333
+ else:
1334
+ s_param.copy_(param)
1335
+
1336
+
1337
+ # Only show the progress bar once on each machine.
1338
+ progress_bar = tqdm(range(args.max_train_steps),bar_format='%s{l_bar}%s%s{bar}%s%s{r_bar}%s'%(bcolors.OKBLUE,bcolors.ENDC, bcolors.OKBLUE, bcolors.ENDC,bcolors.OKBLUE,bcolors.ENDC,), disable=not accelerator.is_local_main_process)
1339
+ progress_bar_inter_epoch = tqdm(range(num_update_steps_per_epoch),bar_format='%s{l_bar}%s%s{bar}%s%s{r_bar}%s'%(bcolors.OKBLUE,bcolors.ENDC, bcolors.OKGREEN, bcolors.ENDC,bcolors.OKBLUE,bcolors.ENDC,), disable=not accelerator.is_local_main_process)
1340
+ progress_bar_e = tqdm(range(args.num_train_epochs),bar_format='%s{l_bar}%s%s{bar}%s%s{r_bar}%s'%(bcolors.OKBLUE,bcolors.ENDC, bcolors.OKGREEN, bcolors.ENDC,bcolors.OKBLUE,bcolors.ENDC,), disable=not accelerator.is_local_main_process)
1341
+
1342
+ progress_bar.set_description("Overall Steps")
1343
+ progress_bar_inter_epoch.set_description("Steps To Epoch")
1344
+ progress_bar_e.set_description("Overall Epochs")
1345
+ global_step = 0
1346
+ loss_avg = AverageMeter("loss_avg", max_eta=0.999)
1347
+ gan_loss_avg = AverageMeter("gan_loss_avg", max_eta=0.999)
1348
+ text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
1349
+ if args.send_telegram_updates:
1350
+ try:
1351
+ send_telegram_message(f"Starting training with the following settings:\n\n{format_dict(args.__dict__)}", args.telegram_chat_id, args.telegram_token)
1352
+ except:
1353
+ pass
1354
+ try:
1355
+ tqdm.write(f"{bcolors.OKBLUE}Starting Training!{bcolors.ENDC}")
1356
+ try:
1357
+ def toggle_gui(event=None):
1358
+ if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("g"):
1359
+ tqdm.write(f"{bcolors.WARNING}GUI will boot as soon as the current step is done.{bcolors.ENDC}")
1360
+ nonlocal mid_generation
1361
+ if mid_generation == True:
1362
+ mid_generation = False
1363
+ tqdm.write(f"{bcolors.WARNING}Cancelled GUI.{bcolors.ENDC}")
1364
+ else:
1365
+ mid_generation = True
1366
+
1367
+ def toggle_checkpoint(event=None):
1368
+ if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("s") and not keyboard.is_pressed("alt"):
1369
+ tqdm.write(f"{bcolors.WARNING}Saving the model as soon as this epoch is done.{bcolors.ENDC}")
1370
+ nonlocal mid_checkpoint
1371
+ if mid_checkpoint == True:
1372
+ mid_checkpoint = False
1373
+ tqdm.write(f"{bcolors.WARNING}Cancelled Checkpointing.{bcolors.ENDC}")
1374
+ else:
1375
+ mid_checkpoint = True
1376
+
1377
+ def toggle_sample(event=None):
1378
+ if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("p") and not keyboard.is_pressed("alt"):
1379
+ tqdm.write(f"{bcolors.WARNING}Sampling will begin as soon as this epoch is done.{bcolors.ENDC}")
1380
+ nonlocal mid_sample
1381
+ if mid_sample == True:
1382
+ mid_sample = False
1383
+ tqdm.write(f"{bcolors.WARNING}Cancelled Sampling.{bcolors.ENDC}")
1384
+ else:
1385
+ mid_sample = True
1386
+ def toggle_checkpoint_step(event=None):
1387
+ if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("alt") and keyboard.is_pressed("s"):
1388
+ tqdm.write(f"{bcolors.WARNING}Saving the model as soon as this step is done.{bcolors.ENDC}")
1389
+ nonlocal mid_checkpoint_step
1390
+ if mid_checkpoint_step == True:
1391
+ mid_checkpoint_step = False
1392
+ tqdm.write(f"{bcolors.WARNING}Cancelled Checkpointing.{bcolors.ENDC}")
1393
+ else:
1394
+ mid_checkpoint_step = True
1395
+
1396
+ def toggle_sample_step(event=None):
1397
+ if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("alt") and keyboard.is_pressed("p"):
1398
+ tqdm.write(f"{bcolors.WARNING}Sampling will begin as soon as this step is done.{bcolors.ENDC}")
1399
+ nonlocal mid_sample_step
1400
+ if mid_sample_step == True:
1401
+ mid_sample_step = False
1402
+ tqdm.write(f"{bcolors.WARNING}Cancelled Sampling.{bcolors.ENDC}")
1403
+ else:
1404
+ mid_sample_step = True
1405
+ def toggle_quit_and_save_epoch(event=None):
1406
+ if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("q") and not keyboard.is_pressed("alt"):
1407
+ tqdm.write(f"{bcolors.WARNING}Quitting and saving the model as soon as this epoch is done.{bcolors.ENDC}")
1408
+ nonlocal mid_quit
1409
+ if mid_quit == True:
1410
+ mid_quit = False
1411
+ tqdm.write(f"{bcolors.WARNING}Cancelled Quitting.{bcolors.ENDC}")
1412
+ else:
1413
+ mid_quit = True
1414
+ def toggle_quit_and_save_step(event=None):
1415
+ if keyboard.is_pressed("ctrl") and keyboard.is_pressed("shift") and keyboard.is_pressed("alt") and keyboard.is_pressed("q"):
1416
+ tqdm.write(f"{bcolors.WARNING}Quitting and saving the model as soon as this step is done.{bcolors.ENDC}")
1417
+ nonlocal mid_quit_step
1418
+ if mid_quit_step == True:
1419
+ mid_quit_step = False
1420
+ tqdm.write(f"{bcolors.WARNING}Cancelled Quitting.{bcolors.ENDC}")
1421
+ else:
1422
+ mid_quit_step = True
1423
+ def help(event=None):
1424
+ if keyboard.is_pressed("ctrl") and keyboard.is_pressed("h"):
1425
+ print_instructions()
1426
+ keyboard.on_press_key("g", toggle_gui)
1427
+ keyboard.on_press_key("s", toggle_checkpoint)
1428
+ keyboard.on_press_key("p", toggle_sample)
1429
+ keyboard.on_press_key("s", toggle_checkpoint_step)
1430
+ keyboard.on_press_key("p", toggle_sample_step)
1431
+ keyboard.on_press_key("q", toggle_quit_and_save_epoch)
1432
+ keyboard.on_press_key("q", toggle_quit_and_save_step)
1433
+ keyboard.on_press_key("h", help)
1434
+ print_instructions()
1435
+ except Exception as e:
1436
+ pass
1437
+
1438
+ mid_generation = False
1439
+ mid_checkpoint = False
1440
+ mid_sample = False
1441
+ mid_checkpoint_step = False
1442
+ mid_sample_step = False
1443
+ mid_quit = False
1444
+ mid_quit_step = False
1445
+ #lambda set mid_generation to true
1446
+ if args.run_name:
1447
+ frozen_directory = os.path.join(args.output_dir, f"frozen_text_encoder_{args.run_name}")
1448
+ else:
1449
+ frozen_directory = os.path.join(args.output_dir, "frozen_text_encoder")
1450
+
1451
+ unet_stats = {}
1452
+ discriminator_stats = {}
1453
+
1454
+ os.makedirs(main_sample_dir, exist_ok=True)
1455
+ with open(os.path.join(main_sample_dir, "args.json"), "w") as f:
1456
+ json.dump(args.__dict__, f, indent=2)
1457
+ if args.with_gan:
1458
+ with open(os.path.join(main_sample_dir, "discriminator_config.json"), "w") as f:
1459
+ json.dump(discriminator.config, f, indent=2)
1460
+
1461
+ for epoch in range(args.num_train_epochs):
1462
+ #every 10 epochs print instructions
1463
+ unet.train()
1464
+ if args.train_text_encoder:
1465
+ text_encoder.train()
1466
+
1467
+ #save initial weights
1468
+ if args.sample_on_training_start==True and epoch==0:
1469
+ save_and_sample_weights(epoch,'start',save_model=False)
1470
+
1471
+ if args.train_text_encoder and args.stop_text_encoder_training == epoch:
1472
+ args.stop_text_encoder_training = True
1473
+ if accelerator.is_main_process:
1474
+ tqdm.write(f"{bcolors.WARNING} Stopping text encoder training{bcolors.ENDC}")
1475
+ current_percentage = (epoch/args.num_train_epochs)*100
1476
+ #round to the nearest whole number
1477
+ current_percentage = round(current_percentage,0)
1478
+ try:
1479
+ send_telegram_message(f"Text encoder training stopped at epoch {epoch} which is {current_percentage}% of training. Freezing weights and saving.", args.telegram_chat_id, args.telegram_token)
1480
+ except:
1481
+ pass
1482
+ if os.path.exists(frozen_directory):
1483
+ #delete the folder if it already exists
1484
+ shutil.rmtree(frozen_directory)
1485
+ os.mkdir(frozen_directory)
1486
+ save_and_sample_weights(epoch,'epoch')
1487
+ args.stop_text_encoder_training = epoch
1488
+ progress_bar_inter_epoch.reset(total=num_update_steps_per_epoch)
1489
+ for step, batch in enumerate(train_dataloader):
1490
+ with accelerator.accumulate(unet):
1491
+ # Convert images to latent space
1492
+ with torch.no_grad():
1493
+
1494
+ latent_dist = batch[0][0]
1495
+ latents = latent_dist.sample() * 0.18215
1496
+
1497
+ if args.model_variant == 'inpainting':
1498
+ mask = batch[0][2]
1499
+ mask_mean = batch[0][3]
1500
+ conditioning_latent_dist = batch[0][4]
1501
+ conditioning_latents = conditioning_latent_dist.sample() * 0.18215
1502
+ if args.model_variant == 'depth2img':
1503
+ depth = batch[0][4]
1504
+ if args.sample_from_batch > 0:
1505
+ args.batch_tokens = batch[0][5]
1506
+ # Sample noise that we'll add to the latents
1507
+ # plus an optional low-frequency offset so the model also learns to change the zero-frequency (mean) component freely
1508
+ # https://www.crosslabs.org/blog/diffusion-with-offset-noise
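+ # The offset term draws one random value per (sample, channel) and broadcasts it over the
+ # whole latent, scaled by offset_noise_weight, so each channel also receives a shared
+ # brightness shift on top of the per-pixel Gaussian noise.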
1509
+ if (args.with_offset_noise == True):
1510
+ noise = torch.randn_like(latents) + (args.offset_noise_weight * torch.randn(latents.shape[0], latents.shape[1], 1, 1).to(accelerator.device))
1511
+ else:
1512
+ noise = torch.randn_like(latents)
1513
+
1514
+ bsz = latents.shape[0]
1515
+ # Sample a random timestep for each image
1516
+ timesteps = torch.randint(0, int(noise_scheduler.config.num_train_timesteps * args.max_denoising_strength), (bsz,), device=latents.device)
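+ # max_denoising_strength < 1.0 restricts sampling to the lower (less noisy) portion of the
+ # timestep range; at 1.0 the scheduler's full num_train_timesteps range is used.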
1517
+ timesteps = timesteps.long()
1518
+
1519
+ # Add noise to the latents according to the noise magnitude at each timestep
1520
+ # (this is the forward diffusion process)
1521
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1522
+
1523
+ # Get the text embedding for conditioning
1524
+ with text_enc_context:
1525
+ if args.train_text_encoder:
1526
+ if args.clip_penultimate == True:
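+ # clip_penultimate conditions the UNet on the second-to-last hidden state passed through the
+ # final layer norm instead of the last hidden state, the convention commonly used by SD2.x-style models.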
1527
+ encoder_hidden_states = text_encoder(batch[0][1],output_hidden_states=True)
1528
+ encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
1529
+ else:
1530
+ encoder_hidden_states = text_encoder(batch[0][1])[0]
1531
+ else:
1532
+ encoder_hidden_states = batch[0][1]
1533
+
1534
+
1535
+ # Predict the noise residual
1536
+ mask=None
1537
+ if args.model_variant == 'inpainting':
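+ # The inpainting UNet expects 9 input channels: 4 noisy latents + 1 latent-resolution mask
+ # + 4 latents of the masked conditioning image, concatenated below.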
1538
+ if mask is not None and random.uniform(0, 1) < args.unmasked_probability:
1539
+ # for some steps, predict the unmasked image
1540
+ conditioning_latents = torch.stack([full_mask_by_aspect[tuple([latents.shape[3]*8, latents.shape[2]*8])].squeeze()] * bsz)
1541
+ mask = torch.ones(bsz, 1, latents.shape[2], latents.shape[3]).to(accelerator.device, dtype=weight_dtype)
1542
+ noisy_inpaint_latents = torch.concat([noisy_latents, mask, conditioning_latents], 1)
1543
+ model_pred = unet(noisy_inpaint_latents, timesteps, encoder_hidden_states).sample
1544
+ elif args.model_variant == 'depth2img':
1545
+ noisy_depth_latents = torch.cat([noisy_latents, depth], dim=1)
1546
+ model_pred = unet(noisy_depth_latents, timesteps, encoder_hidden_states, depth).sample
1547
+ elif args.model_variant == "base":
1548
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1549
+
1550
+
1551
+ # Get the target for loss depending on the prediction type
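+ # epsilon-prediction targets the added noise itself; v-prediction targets the scheduler's
+ # velocity v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents, which is what
+ # get_velocity() returns.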
1552
+ if noise_scheduler.config.prediction_type == "epsilon":
1553
+ target = noise
1554
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1555
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
1556
+ else:
1557
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1558
+
1559
+ # GAN stuff
1560
+ # Input: noisy_latents
1561
+ # True output: target
1562
+ # Fake output: model_pred
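+ # Least-squares GAN objective: the discriminator scores (noisy_latents, prediction) pairs and
+ # is trained toward 1 for the true target and 0 for the UNet's prediction (detached here so
+ # only the discriminator updates in this step).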
1563
+
1564
+ if args.with_gan:
1565
+ # Turn on learning for the discriminator, and do an optimization step
1566
+ for param in discriminator.parameters():
1567
+ param.requires_grad = True
1568
+
1569
+ pred_fake = discriminator(torch.cat((noisy_latents, model_pred), 1).detach(), encoder_hidden_states)
1570
+ pred_real = discriminator(torch.cat((noisy_latents, target), 1), encoder_hidden_states)
1571
+ discriminator_loss = F.mse_loss(pred_fake, torch.zeros_like(pred_fake), reduction="mean") + F.mse_loss(pred_real, torch.ones_like(pred_real), reduction="mean")
1572
+ if discriminator_loss.isnan():
1573
+ tqdm.write(f"{bcolors.WARNING}Discriminator loss is NAN, skipping GAN update.{bcolors.ENDC}")
1574
+ else:
1575
+ accelerator.backward(discriminator_loss)
1576
+ if accelerator.sync_gradients:
1577
+ accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm)
1578
+ optimizer_discriminator.step()
1579
+ lr_scheduler_discriminator.step()
1580
+ # Hack to fix NaNs caused by GAN training
1581
+ for name, p in discriminator.named_parameters():
1582
+ if p.isnan().any():
1583
+ fix_nans_(p, name, discriminator_stats[name])
1584
+ else:
1585
+ (std, mean) = torch.std_mean(p)
1586
+ discriminator_stats[name] = (std.item(), mean.item())
1587
+ del std, mean
1588
+ optimizer_discriminator.zero_grad()
1589
+ del pred_real, pred_fake, discriminator_loss
1590
+
1591
+ # Turn off learning for the discriminator for the generator optimization step
1592
+ for param in discriminator.parameters():
1593
+ param.requires_grad = False
1594
+
1595
+ if args.with_prior_preservation:
1596
+ # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
1597
+ """
1598
+ noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
1599
+ noise, noise_prior = torch.chunk(noise, 2, dim=0)
1600
+
1601
+ # Compute instance loss
1602
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
1603
+
1604
+ # Compute prior loss
1605
+ prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
1606
+
1607
+ # Add the prior loss to the instance loss.
1608
+ loss = loss + args.prior_loss_weight * prior_loss
1609
+ """
1610
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
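+ # With prior preservation the batch is assumed to be [instance | class] samples concatenated
+ # along the batch dimension, so the first chunk yields the instance loss and the second the
+ # prior loss, combined as loss + prior_loss_weight * prior_loss.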
1611
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1612
+ target, target_prior = torch.chunk(target, 2, dim=0)
1613
+ if mask is not None and args.model_variant != "inpainting":
1614
+ loss = masked_mse_loss(model_pred.float(), target.float(), mask, reduction="none").mean([1, 2, 3]).mean()
1615
+ prior_loss = masked_mse_loss(model_pred_prior.float(), target_prior.float(), mask, reduction="mean")
1616
+ else:
1617
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
1618
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
1619
+
1620
+ # Add the prior loss to the instance loss.
1621
+ loss = loss + args.prior_loss_weight * prior_loss
1622
+
1623
+ if mask is not None and args.normalize_masked_area_loss:
1624
+ loss = loss / mask_mean
1625
+
1626
+ else:
1627
+ if mask is not None and args.model_variant != "inpainting":
1628
+ loss = masked_mse_loss(model_pred.float(), target.float(), mask, reduction="none").mean([1, 2, 3])
1629
+ loss = loss.mean()
1630
+ else:
1631
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1632
+
1633
+ if mask is not None and args.normalize_masked_area_loss:
1634
+ loss = loss / mask_mean
1635
+
1636
+ base_loss = loss
1637
+
1638
+ if args.with_gan:
1639
+ # Add loss from the GAN
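+ # Generator-side GAN loss: push the discriminator's score for the UNet prediction toward 1.
+ # gan_weight ramps up linearly from 0 over the first gan_warmup steps so the adversarial
+ # signal does not dominate early training.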
1640
+ pred_fake = discriminator(torch.cat((noisy_latents, model_pred), 1), encoder_hidden_states)
1641
+ gan_loss = F.mse_loss(pred_fake, torch.ones_like(pred_fake), reduction="mean")
1642
+ if gan_loss.isnan():
1643
+ tqdm.write(f"{bcolors.WARNING}GAN loss is NAN, skipping GAN loss.{bcolors.ENDC}")
1644
+ else:
1645
+ gan_weight = args.gan_weight
1646
+ if args.gan_warmup and global_step < args.gan_warmup:
1647
+ gan_weight *= global_step / args.gan_warmup
1648
+ loss += gan_weight * gan_loss
1649
+ del pred_fake
1650
+
1651
+ accelerator.backward(loss)
1652
+ if accelerator.sync_gradients:
1653
+ params_to_clip = (
1654
+ itertools.chain(unet.parameters(), text_encoder.parameters())
1655
+ if args.train_text_encoder
1656
+ else unet.parameters()
1657
+ )
1658
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1659
+ optimizer.step()
1660
+ lr_scheduler.step()
1661
+ # Hack to fix NaNs caused by GAN training
1662
+ for name, p in unet.named_parameters():
1663
+ if p.isnan().any():
1664
+ fix_nans_(p, name, unet_stats[name])
1665
+ else:
1666
+ (std, mean) = torch.std_mean(p)
1667
+ unet_stats[name] = (std.item(), mean.item())
1668
+ del std, mean
1669
+ optimizer.zero_grad()
1670
+ loss_avg.update(base_loss.detach_())
1671
+ if args.with_gan and not gan_loss.isnan():
1672
+ gan_loss_avg.update(gan_loss.detach_())
1673
+ if args.use_ema == True:
1674
+ update_ema(ema_unet, unet)
1675
+
1676
+ del loss, model_pred
1677
+ if args.with_prior_preservation:
1678
+ del model_pred_prior
1679
+
1680
+ logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0]}
1681
+ if args.with_gan:
1682
+ logs["gan_loss"] = gan_loss_avg.avg.item()
1683
+ progress_bar.set_postfix(**logs)
1684
+ if not global_step % args.log_interval:
1685
+ accelerator.log(logs, step=global_step)
1686
+
1687
+
1688
+
1689
+ if global_step > 0 and not global_step % args.sample_step_interval:
1690
+ save_and_sample_weights(global_step,'step',save_model=False)
1691
+
1692
+ progress_bar.update(1)
1693
+ progress_bar_inter_epoch.update(1)
1694
+ progress_bar_e.refresh()
1695
+ global_step += 1
1696
+
1697
+ if mid_quit_step==True:
1698
+ accelerator.wait_for_everyone()
1699
+ save_and_sample_weights(global_step,'quit_step')
1700
+ quit()
1701
+ if mid_generation==True:
1702
+ mid_train_playground(global_step)
1703
+ mid_generation=False
1704
+ if mid_checkpoint_step == True:
1705
+ save_and_sample_weights(global_step,'step',save_model=True)
1706
+ mid_checkpoint_step=False
1707
+ mid_sample_step=False
1708
+ elif mid_sample_step == True:
1709
+ save_and_sample_weights(global_step,'step',save_model=False)
1710
+ mid_sample_step=False
1711
+ if global_step >= args.max_train_steps:
1712
+ break
1713
+ progress_bar_e.update(1)
1714
+ if mid_quit==True:
1715
+ accelerator.wait_for_everyone()
1716
+ save_and_sample_weights(epoch,'quit_epoch')
1717
+ quit()
1718
+ if epoch == args.num_train_epochs - 1:
1719
+ save_and_sample_weights(epoch,'epoch',True)
1720
+ elif args.save_every_n_epoch and (epoch + 1) % args.save_every_n_epoch == 0:
1721
+ save_and_sample_weights(epoch,'epoch',True)
1722
+ elif mid_checkpoint==True:
1723
+ save_and_sample_weights(epoch,'epoch',True)
1724
+ mid_checkpoint=False
1725
+ mid_sample=False
1726
+ elif mid_sample==True:
1727
+ save_and_sample_weights(epoch,'epoch',False)
1728
+ mid_sample=False
1729
+ accelerator.wait_for_everyone()
1730
+ except Exception:
1731
+ try:
1732
+ send_telegram_message("Something went wrong while training! :(", args.telegram_chat_id, args.telegram_token)
1733
+ #save_and_sample_weights(global_step,'checkpoint')
1734
+ send_telegram_message(f"Saved checkpoint {global_step} on exit", args.telegram_chat_id, args.telegram_token)
1735
+ except Exception:
1736
+ pass
1737
+ raise
1738
+ except KeyboardInterrupt:
1739
+ if args.send_telegram_updates: send_telegram_message("Training stopped", args.telegram_chat_id, args.telegram_token)
1740
+ try:
1741
+ send_telegram_message("Training finished!", args.telegram_chat_id, args.telegram_token)
1742
+ except:
1743
+ pass
1744
+
1745
+ accelerator.end_training()
1746
+
1747
+
1748
+
1749
+ if __name__ == "__main__":
1750
+ main()
StableTuner_RunPod_Fix/trainer_util.py ADDED
@@ -0,0 +1,435 @@
1
+ import gradio as gr
2
+ import json
3
+ import math
4
+ from pathlib import Path
5
+ from typing import Optional
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint
9
+ from accelerate.logging import get_logger
10
+ from accelerate.utils import set_seed
11
+ from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler, EulerDiscreteScheduler
12
+ from diffusers.optimization import get_scheduler
13
+ from huggingface_hub import HfFolder, Repository, whoami
14
+ from torchvision import transforms
15
+ from tqdm.auto import tqdm
16
+ from typing import Dict, List, Generator, Tuple
17
+ from PIL import Image, ImageFile, ImageFilter
18
+ from collections.abc import Iterable
19
+ from torch import einsum
+ from einops import rearrange
+ from transformers import PretrainedConfig
+ import diffusers
20
+ from dataloaders_util import *
21
+
22
+ # FlashAttention based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main
23
+ # /memory_efficient_attention_pytorch/flash_attention.py LICENSE MIT
24
+ # https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE constants
25
+ EPSILON = 1e-6
26
+
27
+ class bcolors:
28
+ HEADER = '\033[95m'
29
+ OKBLUE = '\033[94m'
30
+ OKCYAN = '\033[96m'
31
+ OKGREEN = '\033[92m'
32
+ WARNING = '\033[93m'
33
+ FAIL = '\033[91m'
34
+ ENDC = '\033[0m'
35
+ BOLD = '\033[1m'
36
+ UNDERLINE = '\033[4m'
37
+ # helper functions
38
+ def print_instructions():
39
+ tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+G' to open up a GUI to play around with the model (will pause training){bcolors.ENDC}")
40
+ tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+S' to save a checkpoint of the current epoch{bcolors.ENDC}")
41
+ tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+P' to generate samples for current epoch{bcolors.ENDC}")
42
+ tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+Q' to save and quit after the current epoch{bcolors.ENDC}")
43
+ tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+ALT+S' to save a checkpoint of the current step{bcolors.ENDC}")
44
+ tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+ALT+P' to generate samples for current step{bcolors.ENDC}")
45
+ tqdm.write(f"{bcolors.WARNING}Use 'CTRL+SHIFT+ALT+Q' to save and quit after the current step{bcolors.ENDC}")
46
+ tqdm.write('')
47
+ tqdm.write(f"{bcolors.WARNING}Use 'CTRL+H' to print this message again.{bcolors.ENDC}")
48
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
49
+ if token is None:
50
+ token = HfFolder.get_token()
51
+ if organization is None:
52
+ username = whoami(token)["name"]
53
+ return f"{username}/{model_id}"
54
+ else:
55
+ return f"{organization}/{model_id}"
56
+
57
+ #function to format a dictionary into a telegram message
58
+ def format_dict(d):
59
+ message = ""
60
+ for key, value in d.items():
61
+ #filter keys that have the word "token" in them
62
+ if "token" in key and "tokenizer" not in key:
63
+ value = "TOKEN"
64
+ if 'id' in key:
65
+ value = "ID"
66
+ #if value is a dictionary, format it recursively
67
+ if isinstance(value, dict):
68
+ for k, v in value.items():
69
+ message += f"\n- {k}: <b>{v}</b> \n"
70
+ elif isinstance(value, list):
71
+ #each value is a new line in the message
72
+ message += f"- {key}:\n\n"
73
+ for v in value:
74
+ message += f" <b>{v}</b>\n\n"
75
+ #if value is a list, format it as a list
76
+ else:
77
+ message += f"- {key}: <b>{value}</b>\n"
78
+ return message
79
+
80
+ def send_telegram_message(message, chat_id, token):
81
+ url = f"https://api.telegram.org/bot{token}/sendMessage?chat_id={chat_id}&text={message}&parse_mode=html&disable_notification=True"
82
+ import requests
83
+ req = requests.get(url)
84
+ if req.status_code != 200:
85
+ raise ValueError(f"Telegram request failed with status code {req.status_code}")
86
+ def send_media_group(chat_id,telegram_token, images, caption=None, reply_to_message_id=None):
87
+ """
88
+ Use this method to send an album of photos. On success, an array of Messages that were sent is returned.
89
+ :param chat_id: chat id
90
+ :param images: list of PIL images to send
91
+ :param caption: caption of image
92
+ :param reply_to_message_id: If the message is a reply, ID of the original message
93
+ :return: response with the sent message
94
+ """
95
+ SEND_MEDIA_GROUP = f'https://api.telegram.org/bot{telegram_token}/sendMediaGroup'
96
+ from io import BytesIO
97
+ import requests
98
+ files = {}
99
+ media = []
100
+ for i, img in enumerate(images):
101
+ with BytesIO() as output:
102
+ img.save(output, format='PNG')
103
+ output.seek(0)
104
+ name = f'photo{i}'
105
+ files[name] = output.read()
106
+ # a list of InputMediaPhoto. attach refers to the name of the file in the files dict
107
+ media.append(dict(type='photo', media=f'attach://{name}'))
108
+ media[0]['caption'] = caption
109
+ media[0]['parse_mode'] = 'HTML'
110
+ return requests.post(SEND_MEDIA_GROUP, data={'chat_id': chat_id, 'media': json.dumps(media),'disable_notification':True, 'reply_to_message_id': reply_to_message_id }, files=files)
111
+ class AverageMeter:
112
+ def __init__(self, name=None, max_eta=None):
113
+ self.name = name
114
+ self.max_eta = max_eta
115
+ self.reset()
116
+
117
+ def reset(self):
118
+ self.count = self.avg = 0
119
+
120
+ @torch.no_grad()
121
+ def update(self, val, n=1):
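+ # Exponential moving average: avg <- eta * avg + (1 - eta) * val with eta = count / (count + n),
+ # which grows toward 1 as more values arrive and is capped at max_eta so new values always
+ # retain some influence.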
122
+ eta = self.count / (self.count + n)
123
+ if self.max_eta:
124
+ eta = min(eta, self.max_eta ** n)
125
+ self.avg += (1 - eta) * (val - self.avg)
126
+ self.count += n
127
+
128
+ def exists(val):
129
+ return val is not None
130
+
131
+
132
+ def default(val, d):
133
+ return val if exists(val) else d
134
+
135
+
136
+ def masked_mse_loss(predicted, target, mask, reduction="none"):
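+ # Zeroes out prediction and target outside the mask so only masked pixels contribute error.
+ # Note the reduction still averages over all elements, which is presumably why the trainer
+ # exposes normalize_masked_area_loss to divide the result by the mean mask coverage.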
137
+ masked_predicted = predicted * mask
138
+ masked_target = target * mask
139
+ return F.mse_loss(masked_predicted, masked_target, reduction=reduction)
140
+
141
+ # flash attention forwards and backwards
142
+ # https://arxiv.org/abs/2205.14135
143
+
144
+
145
+ class FlashAttentionFunction(torch.autograd.function.Function):
146
+ @staticmethod
147
+ @torch.no_grad()
148
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
149
+ """ Algorithm 2 in the paper """
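+ # Streaming softmax attention: q, k and v are processed in buckets while running row-wise
+ # maxima and sums are maintained, so the full attention matrix is never materialized.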
150
+
151
+ device = q.device
152
+ dtype = q.dtype
153
+ max_neg_value = -torch.finfo(q.dtype).max
154
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
155
+
156
+ o = torch.zeros_like(q)
157
+ all_row_sums = torch.zeros(
158
+ (*q.shape[:-1], 1), dtype=dtype, device=device)
159
+ all_row_maxes = torch.full(
160
+ (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
161
+
162
+ scale = (q.shape[-1] ** -0.5)
163
+
164
+ if not exists(mask):
165
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
166
+ else:
167
+ mask = rearrange(mask, 'b n -> b 1 1 n')
168
+ mask = mask.split(q_bucket_size, dim=-1)
169
+
170
+ row_splits = zip(
171
+ q.split(q_bucket_size, dim=-2),
172
+ o.split(q_bucket_size, dim=-2),
173
+ mask,
174
+ all_row_sums.split(q_bucket_size, dim=-2),
175
+ all_row_maxes.split(q_bucket_size, dim=-2),
176
+ )
177
+
178
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
179
+ q_start_index = ind * q_bucket_size - qk_len_diff
180
+
181
+ col_splits = zip(
182
+ k.split(k_bucket_size, dim=-2),
183
+ v.split(k_bucket_size, dim=-2),
184
+ )
185
+
186
+ for k_ind, (kc, vc) in enumerate(col_splits):
187
+ k_start_index = k_ind * k_bucket_size
188
+
189
+ attn_weights = einsum(
190
+ '... i d, ... j d -> ... i j', qc, kc) * scale
191
+
192
+ if exists(row_mask):
193
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
194
+
195
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
196
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
197
+ device=device).triu(q_start_index - k_start_index + 1)
198
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
199
+
200
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
201
+ attn_weights -= block_row_maxes
202
+ exp_weights = torch.exp(attn_weights)
203
+
204
+ if exists(row_mask):
205
+ exp_weights.masked_fill_(~row_mask, 0.)
206
+
207
+ block_row_sums = exp_weights.sum(
208
+ dim=-1, keepdims=True).clamp(min=EPSILON)
209
+
210
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
211
+
212
+ exp_values = einsum(
213
+ '... i j, ... j d -> ... i d', exp_weights, vc)
214
+
215
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
216
+ exp_block_row_max_diff = torch.exp(
217
+ block_row_maxes - new_row_maxes)
218
+
219
+ new_row_sums = exp_row_max_diff * row_sums + \
220
+ exp_block_row_max_diff * block_row_sums
221
+
222
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
223
+ (exp_block_row_max_diff / new_row_sums) * exp_values)
224
+
225
+ row_maxes.copy_(new_row_maxes)
226
+ row_sums.copy_(new_row_sums)
227
+
228
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
229
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
230
+
231
+ return o
232
+
233
+ @staticmethod
234
+ @torch.no_grad()
235
+ def backward(ctx, do):
236
+ """ Algorithm 4 in the paper """
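+ # Recomputes the bucketed attention weights from the saved q, k, v and the stored row
+ # sums/maxima (l, m), then accumulates dq, dk and dv chunk by chunk.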
237
+
238
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
239
+ q, k, v, o, l, m = ctx.saved_tensors
240
+
241
+ device = q.device
242
+
243
+ max_neg_value = -torch.finfo(q.dtype).max
244
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
245
+
246
+ dq = torch.zeros_like(q)
247
+ dk = torch.zeros_like(k)
248
+ dv = torch.zeros_like(v)
249
+
250
+ row_splits = zip(
251
+ q.split(q_bucket_size, dim=-2),
252
+ o.split(q_bucket_size, dim=-2),
253
+ do.split(q_bucket_size, dim=-2),
254
+ mask,
255
+ l.split(q_bucket_size, dim=-2),
256
+ m.split(q_bucket_size, dim=-2),
257
+ dq.split(q_bucket_size, dim=-2)
258
+ )
259
+
260
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
261
+ q_start_index = ind * q_bucket_size - qk_len_diff
262
+
263
+ col_splits = zip(
264
+ k.split(k_bucket_size, dim=-2),
265
+ v.split(k_bucket_size, dim=-2),
266
+ dk.split(k_bucket_size, dim=-2),
267
+ dv.split(k_bucket_size, dim=-2),
268
+ )
269
+
270
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
271
+ k_start_index = k_ind * k_bucket_size
272
+
273
+ attn_weights = einsum(
274
+ '... i d, ... j d -> ... i j', qc, kc) * scale
275
+
276
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
277
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
278
+ device=device).triu(q_start_index - k_start_index + 1)
279
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
280
+
281
+ exp_attn_weights = torch.exp(attn_weights - mc)
282
+
283
+ if exists(row_mask):
284
+ exp_attn_weights.masked_fill_(~row_mask, 0.)
285
+
286
+ p = exp_attn_weights / lc
287
+
288
+ dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
289
+ dp = einsum('... i d, ... j d -> ... i j', doc, vc)
290
+
291
+ D = (doc * oc).sum(dim=-1, keepdims=True)
292
+ ds = p * scale * (dp - D)
293
+
294
+ dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
295
+ dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
296
+
297
+ dqc.add_(dq_chunk)
298
+ dkc.add_(dk_chunk)
299
+ dvc.add_(dv_chunk)
300
+
301
+ return dq, dk, dv, None, None, None, None
302
+
303
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
304
+ text_encoder_config = PretrainedConfig.from_pretrained(
305
+ pretrained_model_name_or_path,
306
+ subfolder="text_encoder",
307
+ revision=revision,
308
+ )
309
+ model_class = text_encoder_config.architectures[0]
310
+
311
+ if model_class == "CLIPTextModel":
312
+ from transformers import CLIPTextModel
313
+
314
+ return CLIPTextModel
315
+ elif model_class == "RobertaSeriesModelWithTransformation":
316
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
317
+
318
+ return RobertaSeriesModelWithTransformation
319
+ else:
320
+ raise ValueError(f"{model_class} is not supported.")
321
+
322
+ def replace_unet_cross_attn_to_flash_attention():
323
+ print("Using FlashAttention")
324
+
325
+ def forward_flash_attn(self, x, context=None, mask=None):
326
+ q_bucket_size = 512
327
+ k_bucket_size = 1024
328
+
329
+ h = self.heads
330
+ q = self.to_q(x)
331
+
332
+ context = context if context is not None else x
333
+ context = context.to(x.dtype)
334
+
335
+ if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
336
+ context_k, context_v = self.hypernetwork.forward(x, context)
337
+ context_k = context_k.to(x.dtype)
338
+ context_v = context_v.to(x.dtype)
339
+ else:
340
+ context_k = context
341
+ context_v = context
342
+
343
+ k = self.to_k(context_k)
344
+ v = self.to_v(context_v)
345
+ del context, x
346
+
347
+ q, k, v = map(lambda t: rearrange(
348
+ t, 'b n (h d) -> b h n d', h=h), (q, k, v))
349
+
350
+ out = FlashAttentionFunction.apply(q, k, v, mask, False,
351
+ q_bucket_size, k_bucket_size)
352
+
353
+ out = rearrange(out, 'b h n d -> b n (h d)')
354
+
355
+ # diffusers 0.6.0
356
+ if type(self.to_out) is torch.nn.Sequential:
357
+ return self.to_out(out)
358
+
359
+ # diffusers 0.7.0
360
+ out = self.to_out[0](out)
361
+ out = self.to_out[1](out)
362
+ return out
363
+
364
+ diffusers.models.attention.CrossAttention.forward = forward_flash_attn
365
+ class Depth2Img:
366
+ def __init__(self,unet,text_encoder,revision,pretrained_model_name_or_path,accelerator):
367
+ self.unet = unet
368
+ self.text_encoder = text_encoder
369
+ self.revision = revision if revision != 'no' else 'fp32'
370
+ self.pretrained_model_name_or_path = pretrained_model_name_or_path
371
+ self.accelerator = accelerator
372
+ self.pipeline = None
373
+ def depth_images(self,paths):
374
+ if self.pipeline is None:
375
+ self.pipeline = DiffusionPipeline.from_pretrained(
376
+ self.pretrained_model_name_or_path,
377
+ unet=self.accelerator.unwrap_model(self.unet),
378
+ text_encoder=self.accelerator.unwrap_model(self.text_encoder),
379
+ revision=self.revision,
380
+ local_files_only=True,)
381
+ self.pipeline.to(self.accelerator.device)
382
+ self.vae_scale_factor = 2 ** (len(self.pipeline.vae.config.block_out_channels) - 1)
383
+ non_depth_image_files = []
384
+ image_paths_by_path = {}
385
+
386
+ for path in paths:
387
+ #if path is list
388
+ if isinstance(path, list):
389
+ img = Path(path[0])
390
+ else:
391
+ img = Path(path)
392
+ if self.get_depth_image_path(img).exists():
393
+ continue
394
+ else:
395
+ non_depth_image_files.append(img)
396
+ image_objects = []
397
+ for image_path in non_depth_image_files:
398
+ image_instance = Image.open(image_path)
399
+ if not image_instance.mode == "RGB":
400
+ image_instance = image_instance.convert("RGB")
401
+ image_instance = self.pipeline.feature_extractor(
402
+ image_instance, return_tensors="pt"
403
+ ).pixel_values
404
+
405
+ image_instance = image_instance.to(self.accelerator.device)
406
+ image_objects.append((image_path, image_instance))
407
+
408
+ for image_path, image_instance in image_objects:
409
+ path = image_path.parent
410
+ ogImg = Image.open(image_path)
411
+ ogImg_x = ogImg.size[0]
412
+ ogImg_y = ogImg.size[1]
413
+ depth_map = self.pipeline.depth_estimator(image_instance).predicted_depth
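+ # The predicted depth is resized back to the source image resolution, min-max normalized to
+ # [-1, 1], lightly blurred and saved next to the image as "<stem>-depth.png".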
414
+ depth_min = torch.amin(depth_map, dim=[0, 1, 2], keepdim=True)
415
+ depth_max = torch.amax(depth_map, dim=[0, 1, 2], keepdim=True)
416
+ depth_map = torch.nn.functional.interpolate(depth_map.unsqueeze(1),size=(ogImg_y, ogImg_x),mode="bicubic",align_corners=False,)
417
+
418
+ depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
419
+ depth_map = depth_map[0,:,:]
420
+ depth_map_image = transforms.ToPILImage()(depth_map)
421
+ depth_map_image = depth_map_image.filter(ImageFilter.GaussianBlur(radius=1))
422
+ depth_map_image.save(self.get_depth_image_path(image_path))
423
+ #quit()
424
+ return 2 ** (len(self.pipeline.vae.config.block_out_channels) - 1)
425
+
426
+ def get_depth_image_path(self,image_path):
427
+ #if image_path is a string, convert it to a Path object
428
+ if isinstance(image_path, str):
429
+ image_path = Path(image_path)
430
+ return image_path.parent / f"{image_path.stem}-depth.png"
431
+
432
+ def fix_nans_(param, name=None, stats=None):
433
+ (std, mean) = stats or (1, 0)
434
+ tqdm.write(f"Fixing NaNs in {name}: shape={tuple(param.shape)}, dtype={param.dtype}, mean={mean}, std={std}")
435
+ param.data = torch.where(param.data.isnan(), torch.randn_like(param.data) * std + mean, param.data).detach()