Spaces:
Runtime error
Runtime error
Damian Stewart
commited on
Commit
•
50b9662
1
Parent(s):
6067469
add tensorboard, validation, sample output
Browse files- StableDiffuser.py +32 -16
- app.py +111 -38
- memory_efficiency.py +1 -1
- requirements.txt +1 -0
- train.py +162 -57
StableDiffuser.py
CHANGED
@@ -36,11 +36,13 @@ class StableDiffuser(torch.nn.Module):
|
|
36 |
def __init__(self,
|
37 |
scheduler='LMS',
|
38 |
keep_pipeline=False,
|
|
|
39 |
repo_id_or_path="CompVis/stable-diffusion-v1-4"):
|
40 |
|
41 |
super().__init__()
|
42 |
|
43 |
self.pipeline = StableDiffusionPipeline.from_pretrained(repo_id_or_path)
|
|
|
44 |
|
45 |
self.vae = self.pipeline.vae
|
46 |
self.unet = self.pipeline.unet
|
@@ -60,8 +62,10 @@ class StableDiffuser(torch.nn.Module):
|
|
60 |
if not keep_pipeline:
|
61 |
del self.pipeline
|
62 |
|
63 |
-
def get_noise(self, batch_size, width, height, generator=None):
|
64 |
param = list(self.parameters())[0]
|
|
|
|
|
65 |
return torch.randn(
|
66 |
(batch_size, self.unet.config.in_channels, width // 8, height // 8),
|
67 |
generator=generator).type(param.dtype).to(param.device)
|
@@ -95,16 +99,20 @@ class StableDiffuser(torch.nn.Module):
|
|
95 |
def set_scheduler_timesteps(self, n_steps):
|
96 |
self.scheduler.set_timesteps(n_steps, device=self.unet.device)
|
97 |
|
98 |
-
def get_initial_latents(self, n_imgs, height, width, n_prompts, generator=None):
|
|
|
|
|
99 |
noise = self.get_noise(n_imgs, height, width, generator=generator).repeat(n_prompts, 1, 1, 1)
|
100 |
latents = noise * self.scheduler.init_noise_sigma
|
101 |
return latents
|
102 |
|
103 |
-
def
|
104 |
text_tokens = self.text_tokenize(prompts)
|
105 |
text_embeddings = self.text_encode(text_tokens)
|
106 |
if negative_prompts is None:
|
107 |
-
negative_prompts = [
|
|
|
|
|
108 |
unconditional_tokens = self.text_tokenize(negative_prompts)
|
109 |
unconditional_embeddings = self.text_encode(unconditional_tokens)
|
110 |
text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)
|
@@ -136,12 +144,12 @@ class StableDiffuser(torch.nn.Module):
|
|
136 |
@torch.no_grad()
|
137 |
def diffusion(self,
|
138 |
latents,
|
139 |
-
|
140 |
end_iteration=1000,
|
141 |
start_iteration=0,
|
142 |
return_steps=False,
|
143 |
pred_x0=False,
|
144 |
-
trace_args=None,
|
145 |
show_progress=True,
|
146 |
use_amp=False,
|
147 |
**kwargs):
|
@@ -159,7 +167,7 @@ class StableDiffuser(torch.nn.Module):
|
|
159 |
noise_pred = self.predict_noise(
|
160 |
iteration,
|
161 |
latents,
|
162 |
-
|
163 |
**kwargs)
|
164 |
|
165 |
# compute the previous noisy sample x_t -> x_t-1
|
@@ -182,30 +190,38 @@ class StableDiffuser(torch.nn.Module):
|
|
182 |
|
183 |
@torch.no_grad()
|
184 |
def __call__(self,
|
185 |
-
prompts,
|
186 |
-
negative_prompts,
|
187 |
-
|
188 |
-
|
|
|
189 |
n_steps=50,
|
190 |
n_imgs=1,
|
191 |
end_iteration=None,
|
192 |
generator=None,
|
|
|
193 |
**kwargs
|
194 |
):
|
195 |
|
196 |
assert 0 <= n_steps <= 1000
|
197 |
|
198 |
-
if
|
199 |
-
prompts
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
self.set_scheduler_timesteps(n_steps)
|
202 |
-
latents = self.get_initial_latents(n_imgs, height, width,
|
203 |
-
|
204 |
end_iteration = end_iteration or n_steps
|
205 |
latents_steps, trace_steps = self.diffusion(
|
206 |
latents,
|
207 |
-
|
208 |
end_iteration=end_iteration,
|
|
|
209 |
**kwargs
|
210 |
)
|
211 |
|
|
|
36 |
def __init__(self,
|
37 |
scheduler='LMS',
|
38 |
keep_pipeline=False,
|
39 |
+
native_img_size=512,
|
40 |
repo_id_or_path="CompVis/stable-diffusion-v1-4"):
|
41 |
|
42 |
super().__init__()
|
43 |
|
44 |
self.pipeline = StableDiffusionPipeline.from_pretrained(repo_id_or_path)
|
45 |
+
self.native_image_size = native_img_size
|
46 |
|
47 |
self.vae = self.pipeline.vae
|
48 |
self.unet = self.pipeline.unet
|
|
|
62 |
if not keep_pipeline:
|
63 |
del self.pipeline
|
64 |
|
65 |
+
def get_noise(self, batch_size, width=None, height=None, generator=None):
|
66 |
param = list(self.parameters())[0]
|
67 |
+
width = width or self.native_image_size
|
68 |
+
height = height or self.native_image_size
|
69 |
return torch.randn(
|
70 |
(batch_size, self.unet.config.in_channels, width // 8, height // 8),
|
71 |
generator=generator).type(param.dtype).to(param.device)
|
|
|
99 |
def set_scheduler_timesteps(self, n_steps):
|
100 |
self.scheduler.set_timesteps(n_steps, device=self.unet.device)
|
101 |
|
102 |
+
def get_initial_latents(self, n_imgs, height=None, width=None, n_prompts=1, generator=None):
|
103 |
+
height = height or self.native_image_size
|
104 |
+
width = width or self.native_image_size
|
105 |
noise = self.get_noise(n_imgs, height, width, generator=generator).repeat(n_prompts, 1, 1, 1)
|
106 |
latents = noise * self.scheduler.init_noise_sigma
|
107 |
return latents
|
108 |
|
109 |
+
def get_cond_and_uncond_embeddings(self, prompts, negative_prompts=None, n_imgs=1):
|
110 |
text_tokens = self.text_tokenize(prompts)
|
111 |
text_embeddings = self.text_encode(text_tokens)
|
112 |
if negative_prompts is None:
|
113 |
+
negative_prompts = []
|
114 |
+
while len(negative_prompts) < len(prompts):
|
115 |
+
negative_prompts.append("")
|
116 |
unconditional_tokens = self.text_tokenize(negative_prompts)
|
117 |
unconditional_embeddings = self.text_encode(unconditional_tokens)
|
118 |
text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)
|
|
|
144 |
@torch.no_grad()
|
145 |
def diffusion(self,
|
146 |
latents,
|
147 |
+
uncond_and_cond_embeddings,
|
148 |
end_iteration=1000,
|
149 |
start_iteration=0,
|
150 |
return_steps=False,
|
151 |
pred_x0=False,
|
152 |
+
trace_args=None,
|
153 |
show_progress=True,
|
154 |
use_amp=False,
|
155 |
**kwargs):
|
|
|
167 |
noise_pred = self.predict_noise(
|
168 |
iteration,
|
169 |
latents,
|
170 |
+
uncond_and_cond_embeddings,
|
171 |
**kwargs)
|
172 |
|
173 |
# compute the previous noisy sample x_t -> x_t-1
|
|
|
190 |
|
191 |
@torch.no_grad()
|
192 |
def __call__(self,
|
193 |
+
prompts=None,
|
194 |
+
negative_prompts=None,
|
195 |
+
combined_embeddings=None, # uncond first, then cond
|
196 |
+
width=None,
|
197 |
+
height=None,
|
198 |
n_steps=50,
|
199 |
n_imgs=1,
|
200 |
end_iteration=None,
|
201 |
generator=None,
|
202 |
+
use_amp=False,
|
203 |
**kwargs
|
204 |
):
|
205 |
|
206 |
assert 0 <= n_steps <= 1000
|
207 |
|
208 |
+
if combined_embeddings is None:
|
209 |
+
assert prompts is not None, "missing prompts or combined_embeddings"
|
210 |
+
combined_embeddings = diffuser.get_cond_and_uncond_embeddings(prompts, negative_prompts, n_imgs=n_imgs)
|
211 |
+
|
212 |
+
width = width or self.native_image_size
|
213 |
+
height = height or self.native_image_size
|
214 |
+
num_prompts = combined_embeddings.shape[0] // 2
|
215 |
|
216 |
self.set_scheduler_timesteps(n_steps)
|
217 |
+
latents = self.get_initial_latents(n_imgs, height, width, num_prompts, generator=generator)
|
218 |
+
|
219 |
end_iteration = end_iteration or n_steps
|
220 |
latents_steps, trace_steps = self.diffusion(
|
221 |
latents,
|
222 |
+
combined_embeddings,
|
223 |
end_iteration=end_iteration,
|
224 |
+
use_amp=use_amp,
|
225 |
**kwargs
|
226 |
)
|
227 |
|
app.py
CHANGED
@@ -7,12 +7,13 @@ from diffusers.utils import is_xformers_available
|
|
7 |
from finetuning import FineTunedModel
|
8 |
from StableDiffuser import StableDiffuser
|
9 |
from memory_efficiency import MemoryEfficiencyWrapper
|
10 |
-
from train import train
|
11 |
|
12 |
import os
|
13 |
|
|
|
14 |
def populate_model_map():
|
15 |
-
model_map
|
16 |
for model_file in os.listdir('models'):
|
17 |
path = 'models/' + model_file
|
18 |
if any([existing_path == path for existing_path in model_map.values()]):
|
@@ -28,6 +29,7 @@ SHARED_UI_WARNING = f'''## Attention - Training using the ESD-u method does not
|
|
28 |
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
|
29 |
'''
|
30 |
|
|
|
31 |
|
32 |
class Demo:
|
33 |
|
@@ -70,24 +72,11 @@ class Demo:
|
|
70 |
self.negative_prompt_input_infr = gr.Text(
|
71 |
label="Negative prompt"
|
72 |
)
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
choices= list(model_map.keys()),
|
79 |
-
value='Van Gogh',
|
80 |
-
interactive=True
|
81 |
-
)
|
82 |
-
self.model_reload_button = gr.Button(
|
83 |
-
value="🔄",
|
84 |
-
interactive=True
|
85 |
-
)
|
86 |
-
|
87 |
-
self.seed_infr = gr.Number(
|
88 |
-
label="Seed",
|
89 |
-
value=42
|
90 |
-
)
|
91 |
self.img_width_infr = gr.Slider(
|
92 |
label="Image width",
|
93 |
minimum=256,
|
@@ -95,7 +84,6 @@ class Demo:
|
|
95 |
value=512,
|
96 |
step=64
|
97 |
)
|
98 |
-
|
99 |
self.img_height_infr = gr.Slider(
|
100 |
label="Image height",
|
101 |
minimum=256,
|
@@ -104,6 +92,18 @@ class Demo:
|
|
104 |
step=64
|
105 |
)
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
self.base_repo_id_or_path_input_infr = gr.Text(
|
108 |
label="Base model",
|
109 |
value="CompVis/stable-diffusion-v1-4",
|
@@ -131,14 +131,12 @@ class Demo:
|
|
131 |
with gr.Tab("Train") as training_column:
|
132 |
|
133 |
with gr.Row():
|
134 |
-
|
135 |
self.explain_train= gr.Markdown(interactive=False,
|
136 |
value='In this part you can erase any concept from Stable Diffusion. Enter a prompt for the concept or style you want to erase, and select ESD-x if you want to focus erasure on prompts that mention the concept explicitly. [NOTE: ESD-u is currently unavailable in this space. But you can duplicate the space and run it on GPU with VRAM >40GB for enabling ESD-u]. With default settings, it takes about 15 minutes to fine-tune the model; then you can try inference above or download the weights. The training code used here is slightly different than the code tested in the original paper. Code and details are at [github link](https://github.com/rohitgandikota/erasing).')
|
137 |
|
138 |
with gr.Row():
|
139 |
|
140 |
with gr.Column(scale=3):
|
141 |
-
|
142 |
self.train_model_input = gr.Text(
|
143 |
label="Model to Edit",
|
144 |
value="CompVis/stable-diffusion-v1-4",
|
@@ -196,7 +194,7 @@ class Demo:
|
|
196 |
)
|
197 |
self.train_save_every_input = gr.Number(
|
198 |
value=-1,
|
199 |
-
label="Save
|
200 |
info="If >0, save the model throughout training at the given step interval."
|
201 |
)
|
202 |
|
@@ -210,6 +208,28 @@ class Demo:
|
|
210 |
self.train_use_gradient_checkpointing_input = gr.Checkbox(
|
211 |
label="Gradient checkpointing", value=False)
|
212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
with gr.Column(scale=1):
|
214 |
|
215 |
self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)
|
@@ -219,7 +239,7 @@ class Demo:
|
|
219 |
)
|
220 |
|
221 |
self.train_cancel_button = gr.Button(
|
222 |
-
value="Cancel
|
223 |
)
|
224 |
|
225 |
self.download = gr.Files()
|
@@ -260,6 +280,7 @@ class Demo:
|
|
260 |
value='', variant='primary', label='Status', interactive=False)
|
261 |
self.export_button = gr.Button(
|
262 |
value="Export")
|
|
|
263 |
|
264 |
self.infr_button.click(self.inference, inputs = [
|
265 |
self.prompt_input_infr,
|
@@ -292,10 +313,16 @@ class Demo:
|
|
292 |
self.train_use_gradient_checkpointing_input,
|
293 |
self.train_seed_input,
|
294 |
self.train_save_every_input,
|
|
|
|
|
|
|
|
|
295 |
],
|
296 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
297 |
)
|
298 |
-
self.train_cancel_button.click(
|
|
|
|
|
299 |
|
300 |
self.export_button.click(self.export, inputs = [
|
301 |
self.model_dropdown_export,
|
@@ -303,23 +330,51 @@ class Demo:
|
|
303 |
self.save_path_input_export,
|
304 |
self.save_half_export
|
305 |
],
|
306 |
-
outputs=[self.export_status]
|
307 |
)
|
308 |
|
309 |
def reload_models(self, model_dropdown):
|
310 |
current_model_name = model_dropdown
|
311 |
global model_map
|
312 |
-
|
313 |
-
return [
|
|
|
|
|
|
|
|
|
314 |
|
315 |
def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
|
316 |
use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
|
317 |
seed=-1, save_every=-1,
|
318 |
-
|
319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
if self.training:
|
321 |
return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
|
322 |
|
|
|
|
|
323 |
print(f"Training {repo_id_or_path} at {img_size} to remove '{prompt}'.")
|
324 |
print(f" {train_method}, negative guidance {neg_guidance}, lr {lr}, {iterations} iterations.")
|
325 |
print(f" {'✅' if use_gradient_checkpointing else '❌'} gradient checkpointing")
|
@@ -348,23 +403,38 @@ class Demo:
|
|
348 |
break
|
349 |
# repeat until a not-in-use path is found
|
350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
351 |
try:
|
352 |
self.training = True
|
353 |
self.train_cancel_button.update(interactive=True)
|
354 |
-
train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
|
355 |
use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing,
|
356 |
-
seed=int(seed),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
357 |
finally:
|
358 |
self.training = False
|
359 |
self.train_cancel_button.update(interactive=False)
|
360 |
|
361 |
torch.cuda.empty_cache()
|
362 |
|
363 |
-
new_model_name
|
364 |
-
|
365 |
|
366 |
return [gr.update(interactive=True, value='Train'),
|
367 |
-
gr.update(value=
|
368 |
save_path,
|
369 |
gr.Dropdown.update(choices=list(model_map.keys()), value=new_model_name)]
|
370 |
|
@@ -373,7 +443,7 @@ class Demo:
|
|
373 |
checkpoint = torch.load(model_path)
|
374 |
diffuser = StableDiffuser(scheduler='DDIM',
|
375 |
keep_pipeline=True,
|
376 |
-
repo_id_or_path=base_repo_id_or_path
|
377 |
).eval()
|
378 |
finetuner = FineTunedModel.from_checkpoint(diffuser, checkpoint).eval()
|
379 |
with finetuner:
|
@@ -381,7 +451,10 @@ class Demo:
|
|
381 |
diffuser = diffuser.half()
|
382 |
diffuser.pipeline.to('cpu', torch_dtype=torch.float16)
|
383 |
diffuser.pipeline.save_pretrained(save_path)
|
384 |
-
|
|
|
|
|
|
|
385 |
|
386 |
|
387 |
def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
|
|
|
7 |
from finetuning import FineTunedModel
|
8 |
from StableDiffuser import StableDiffuser
|
9 |
from memory_efficiency import MemoryEfficiencyWrapper
|
10 |
+
from train import train, training_should_cancel
|
11 |
|
12 |
import os
|
13 |
|
14 |
+
model_map = {}
|
15 |
def populate_model_map():
|
16 |
+
global model_map
|
17 |
for model_file in os.listdir('models'):
|
18 |
path = 'models/' + model_file
|
19 |
if any([existing_path == path for existing_path in model_map.values()]):
|
|
|
29 |
<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
|
30 |
'''
|
31 |
|
32 |
+
# work around Gradio's weird threading
|
33 |
|
34 |
class Demo:
|
35 |
|
|
|
72 |
self.negative_prompt_input_infr = gr.Text(
|
73 |
label="Negative prompt"
|
74 |
)
|
75 |
+
self.seed_infr = gr.Number(
|
76 |
+
label="Seed",
|
77 |
+
value=42
|
78 |
+
)
|
79 |
+
with gr.Row(scale=1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
self.img_width_infr = gr.Slider(
|
81 |
label="Image width",
|
82 |
minimum=256,
|
|
|
84 |
value=512,
|
85 |
step=64
|
86 |
)
|
|
|
87 |
self.img_height_infr = gr.Slider(
|
88 |
label="Image height",
|
89 |
minimum=256,
|
|
|
92 |
step=64
|
93 |
)
|
94 |
|
95 |
+
with gr.Row(scale=1):
|
96 |
+
self.model_dropdown = gr.Dropdown(
|
97 |
+
label="ESD Model",
|
98 |
+
choices= list(model_map.keys()),
|
99 |
+
value='Van Gogh',
|
100 |
+
interactive=True
|
101 |
+
)
|
102 |
+
self.model_reload_button = gr.Button(
|
103 |
+
value="🔄",
|
104 |
+
interactive=True
|
105 |
+
)
|
106 |
+
|
107 |
self.base_repo_id_or_path_input_infr = gr.Text(
|
108 |
label="Base model",
|
109 |
value="CompVis/stable-diffusion-v1-4",
|
|
|
131 |
with gr.Tab("Train") as training_column:
|
132 |
|
133 |
with gr.Row():
|
|
|
134 |
self.explain_train= gr.Markdown(interactive=False,
|
135 |
value='In this part you can erase any concept from Stable Diffusion. Enter a prompt for the concept or style you want to erase, and select ESD-x if you want to focus erasure on prompts that mention the concept explicitly. [NOTE: ESD-u is currently unavailable in this space. But you can duplicate the space and run it on GPU with VRAM >40GB for enabling ESD-u]. With default settings, it takes about 15 minutes to fine-tune the model; then you can try inference above or download the weights. The training code used here is slightly different than the code tested in the original paper. Code and details are at [github link](https://github.com/rohitgandikota/erasing).')
|
136 |
|
137 |
with gr.Row():
|
138 |
|
139 |
with gr.Column(scale=3):
|
|
|
140 |
self.train_model_input = gr.Text(
|
141 |
label="Model to Edit",
|
142 |
value="CompVis/stable-diffusion-v1-4",
|
|
|
194 |
)
|
195 |
self.train_save_every_input = gr.Number(
|
196 |
value=-1,
|
197 |
+
label="Save Every N Steps",
|
198 |
info="If >0, save the model throughout training at the given step interval."
|
199 |
)
|
200 |
|
|
|
208 |
self.train_use_gradient_checkpointing_input = gr.Checkbox(
|
209 |
label="Gradient checkpointing", value=False)
|
210 |
|
211 |
+
self.train_validation_prompts = gr.TextArea(
|
212 |
+
label="Validation Prompts",
|
213 |
+
placeholder="Probably, you want to put the \"Prompt to Erase\" in here as the first entry...",
|
214 |
+
value='',
|
215 |
+
info="Prompts for producing validation graphs, one per line."
|
216 |
+
)
|
217 |
+
self.train_sample_positive_prompts = gr.TextArea(
|
218 |
+
label="Sample Prompts",
|
219 |
+
value='',
|
220 |
+
info="Positive prompts for generating sample images, one per line."
|
221 |
+
)
|
222 |
+
self.train_sample_negative_prompts = gr.TextArea(
|
223 |
+
label="Sample Negative Prompts",
|
224 |
+
value='',
|
225 |
+
info="Negative prompts for use when generating sample images. One for each positive prompt, or leave empty for none."
|
226 |
+
)
|
227 |
+
self.train_validate_every_n_steps = gr.Number(
|
228 |
+
label="Validate Every N Steps",
|
229 |
+
value=20,
|
230 |
+
info="Validation and sample generation will be run at intervals of this many steps"
|
231 |
+
)
|
232 |
+
|
233 |
with gr.Column(scale=1):
|
234 |
|
235 |
self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)
|
|
|
239 |
)
|
240 |
|
241 |
self.train_cancel_button = gr.Button(
|
242 |
+
value="Cancel Training"
|
243 |
)
|
244 |
|
245 |
self.download = gr.Files()
|
|
|
280 |
value='', variant='primary', label='Status', interactive=False)
|
281 |
self.export_button = gr.Button(
|
282 |
value="Export")
|
283 |
+
self.export_download = gr.Files()
|
284 |
|
285 |
self.infr_button.click(self.inference, inputs = [
|
286 |
self.prompt_input_infr,
|
|
|
313 |
self.train_use_gradient_checkpointing_input,
|
314 |
self.train_seed_input,
|
315 |
self.train_save_every_input,
|
316 |
+
self.train_validation_prompts,
|
317 |
+
self.train_sample_positive_prompts,
|
318 |
+
self.train_sample_negative_prompts,
|
319 |
+
self.train_validate_every_n_steps
|
320 |
],
|
321 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
322 |
)
|
323 |
+
self.train_cancel_button.click(self.cancel_training,
|
324 |
+
inputs=[],
|
325 |
+
outputs=[self.train_cancel_button])
|
326 |
|
327 |
self.export_button.click(self.export, inputs = [
|
328 |
self.model_dropdown_export,
|
|
|
330 |
self.save_path_input_export,
|
331 |
self.save_half_export
|
332 |
],
|
333 |
+
outputs=[self.export_button, self.export_status, self.export_download]
|
334 |
)
|
335 |
|
336 |
def reload_models(self, model_dropdown):
|
337 |
current_model_name = model_dropdown
|
338 |
global model_map
|
339 |
+
populate_model_map()
|
340 |
+
return [self.model_dropdown.update(choices=list(model_map.keys()), value=current_model_name)]
|
341 |
+
|
342 |
+
def cancel_training(self):
|
343 |
+
train.training_should_cancel = True
|
344 |
+
return [gr.update(value="Cancelling...", interactive=False)]
|
345 |
|
346 |
def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
|
347 |
use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
|
348 |
seed=-1, save_every=-1,
|
349 |
+
validation_prompts: str=None, sample_positive_prompts: str=None, sample_negative_prompts: str=None, validate_every_n_steps=-1,
|
350 |
+
pbar=gr.Progress(track_tqdm=True)):
|
351 |
+
"""
|
352 |
+
|
353 |
+
:param repo_id_or_path:
|
354 |
+
:param img_size:
|
355 |
+
:param prompt:
|
356 |
+
:param train_method:
|
357 |
+
:param neg_guidance:
|
358 |
+
:param iterations:
|
359 |
+
:param lr:
|
360 |
+
:param use_adamw8bit:
|
361 |
+
:param use_xformers:
|
362 |
+
:param use_amp:
|
363 |
+
:param use_gradient_checkpointing:
|
364 |
+
:param seed:
|
365 |
+
:param save_every:
|
366 |
+
:param validation_prompts: split on \n
|
367 |
+
:param sample_positive_prompts: split on \n
|
368 |
+
:param sample_negative_prompts: split on \n
|
369 |
+
:param validate_every_n_steps: split on \n
|
370 |
+
:param pbar:
|
371 |
+
:return:
|
372 |
+
"""
|
373 |
if self.training:
|
374 |
return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
|
375 |
|
376 |
+
train.training_should_cancel = False
|
377 |
+
|
378 |
print(f"Training {repo_id_or_path} at {img_size} to remove '{prompt}'.")
|
379 |
print(f" {train_method}, negative guidance {neg_guidance}, lr {lr}, {iterations} iterations.")
|
380 |
print(f" {'✅' if use_gradient_checkpointing else '❌'} gradient checkpointing")
|
|
|
403 |
break
|
404 |
# repeat until a not-in-use path is found
|
405 |
|
406 |
+
validation_prompts = [] if validation_prompts is None else validation_prompts.split('\n')
|
407 |
+
sample_positive_prompts = [] if sample_positive_prompts is None else sample_positive_prompts.split('\n')
|
408 |
+
sample_negative_prompts = [] if sample_negative_prompts is None else sample_negative_prompts.split('\n')
|
409 |
+
print(f"validation prompts: {validation_prompts}")
|
410 |
+
print(f"sample positive prompts: {sample_positive_prompts}")
|
411 |
+
print(f"sample negative prompts: {sample_negative_prompts}")
|
412 |
+
|
413 |
try:
|
414 |
self.training = True
|
415 |
self.train_cancel_button.update(interactive=True)
|
416 |
+
save_path = train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
|
417 |
use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing,
|
418 |
+
seed=int(seed), save_every_n_steps=int(save_every),
|
419 |
+
validate_every_n_steps=validate_every_n_steps, validation_prompts=validation_prompts,
|
420 |
+
sample_positive_prompts=sample_positive_prompts, sample_negative_prompts=sample_negative_prompts)
|
421 |
+
if save_path is None:
|
422 |
+
new_model_name = None
|
423 |
+
finished_message = "Training cancelled."
|
424 |
+
else:
|
425 |
+
new_model_name = f'{os.path.basename(save_path)}'
|
426 |
+
finished_message = f'Done Training! Try your model ({new_model_name}) in the "Test" tab'
|
427 |
finally:
|
428 |
self.training = False
|
429 |
self.train_cancel_button.update(interactive=False)
|
430 |
|
431 |
torch.cuda.empty_cache()
|
432 |
|
433 |
+
if new_model_name is not None:
|
434 |
+
model_map[new_model_name] = save_path
|
435 |
|
436 |
return [gr.update(interactive=True, value='Train'),
|
437 |
+
gr.update(value=finished_message),
|
438 |
save_path,
|
439 |
gr.Dropdown.update(choices=list(model_map.keys()), value=new_model_name)]
|
440 |
|
|
|
443 |
checkpoint = torch.load(model_path)
|
444 |
diffuser = StableDiffuser(scheduler='DDIM',
|
445 |
keep_pipeline=True,
|
446 |
+
repo_id_or_path=base_repo_id_or_path,
|
447 |
).eval()
|
448 |
finetuner = FineTunedModel.from_checkpoint(diffuser, checkpoint).eval()
|
449 |
with finetuner:
|
|
|
451 |
diffuser = diffuser.half()
|
452 |
diffuser.pipeline.to('cpu', torch_dtype=torch.float16)
|
453 |
diffuser.pipeline.save_pretrained(save_path)
|
454 |
+
|
455 |
+
return [gr.update(interactive=True, value='Export'),
|
456 |
+
gr.update(value=f'Done Exporting!'),
|
457 |
+
save_path]
|
458 |
|
459 |
|
460 |
def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
|
memory_efficiency.py
CHANGED
@@ -44,7 +44,7 @@ class MemoryEfficiencyWrapper:
|
|
44 |
print("xformers disabled via arg, using attention slicing instead")
|
45 |
self.diffuser.unet.set_attention_slice("auto")
|
46 |
|
47 |
-
self.diffuser.vae = self.diffuser.vae.to(self.diffuser.vae.device, dtype=torch.float16 if self.use_amp else torch.float32)
|
48 |
self.diffuser.unet = self.diffuser.unet.to(self.diffuser.unet.device, dtype=torch.float32)
|
49 |
|
50 |
try:
|
|
|
44 |
print("xformers disabled via arg, using attention slicing instead")
|
45 |
self.diffuser.unet.set_attention_slice("auto")
|
46 |
|
47 |
+
#self.diffuser.vae = self.diffuser.vae.to(self.diffuser.vae.device, dtype=torch.float16 if self.use_amp else torch.float32)
|
48 |
self.diffuser.unet = self.diffuser.unet.to(self.diffuser.unet.device, dtype=torch.float32)
|
49 |
|
50 |
try:
|
requirements.txt
CHANGED
@@ -9,3 +9,4 @@ git+https://github.com/davidbau/baukit.git
|
|
9 |
xformers
|
10 |
bitsandbytes==0.38.1
|
11 |
safetensors
|
|
|
|
9 |
xformers
|
10 |
bitsandbytes==0.38.1
|
11 |
safetensors
|
12 |
+
tensorboard
|
train.py
CHANGED
@@ -1,7 +1,10 @@
|
|
|
|
1 |
import random
|
2 |
|
3 |
from accelerate.utils import set_seed
|
|
|
4 |
from torch.cuda.amp import autocast
|
|
|
5 |
|
6 |
from StableDiffuser import StableDiffuser
|
7 |
from finetuning import FineTunedModel
|
@@ -10,13 +13,90 @@ from tqdm import tqdm
|
|
10 |
|
11 |
from isolate_rng import isolate_rng
|
12 |
from memory_efficiency import MemoryEfficiencyWrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
|
16 |
-
use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
nsteps = 50
|
19 |
-
|
|
|
|
|
20 |
|
21 |
memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
|
22 |
use_gradient_checkpointing=use_gradient_checkpointing )
|
@@ -40,16 +120,18 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
|
|
40 |
pbar = tqdm(range(iterations))
|
41 |
|
42 |
with torch.no_grad():
|
43 |
-
neutral_text_embeddings = diffuser.
|
44 |
-
positive_text_embeddings = diffuser.
|
|
|
|
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
del diffuser.tokenizer
|
49 |
|
50 |
-
|
|
|
51 |
|
52 |
-
|
53 |
|
54 |
if seed == -1:
|
55 |
seed = random.randint(0, 2 ** 30)
|
@@ -58,65 +140,88 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
|
|
58 |
prev_losses = []
|
59 |
start_loss = None
|
60 |
max_prev_loss_count = 10
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
65 |
|
66 |
-
|
67 |
-
|
68 |
|
69 |
-
|
70 |
-
latents_steps, _ = diffuser.diffusion(
|
71 |
-
latents,
|
72 |
-
positive_text_embeddings,
|
73 |
-
start_iteration=0,
|
74 |
-
end_iteration=iteration,
|
75 |
-
guidance_scale=3,
|
76 |
-
show_progress=False,
|
77 |
-
use_amp=use_amp
|
78 |
-
)
|
79 |
-
|
80 |
-
diffuser.set_scheduler_timesteps(1000)
|
81 |
-
iteration = int(iteration / nsteps * 1000)
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
|
86 |
|
87 |
-
|
88 |
-
with autocast(enabled=use_amp):
|
89 |
-
negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
|
90 |
|
91 |
-
|
92 |
-
|
|
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
optimizer.zero_grad()
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
if len(prev_losses) > max_prev_loss_count:
|
102 |
-
prev_losses.pop(0)
|
103 |
-
if start_loss is None:
|
104 |
-
start_loss = prev_losses[-1]
|
105 |
-
if len(prev_losses) >= max_prev_loss_count:
|
106 |
-
moving_average_loss = sum(prev_losses) / len(prev_losses)
|
107 |
-
print(
|
108 |
-
f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()}")
|
109 |
-
else:
|
110 |
-
print(f"step {i}: loss={loss.item()}")
|
111 |
|
112 |
-
|
113 |
-
|
|
|
|
|
114 |
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
-
del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents
|
118 |
|
119 |
-
torch.cuda.empty_cache()
|
120 |
if __name__ == '__main__':
|
121 |
|
122 |
import argparse
|
|
|
1 |
+
import os.path
|
2 |
import random
|
3 |
|
4 |
from accelerate.utils import set_seed
|
5 |
+
from diffusers import StableDiffusionPipeline
|
6 |
from torch.cuda.amp import autocast
|
7 |
+
from torchvision import transforms
|
8 |
|
9 |
from StableDiffuser import StableDiffuser
|
10 |
from finetuning import FineTunedModel
|
|
|
13 |
|
14 |
from isolate_rng import isolate_rng
|
15 |
from memory_efficiency import MemoryEfficiencyWrapper
|
16 |
+
from torch.utils.tensorboard import SummaryWriter
|
17 |
+
|
18 |
+
training_should_cancel = False
|
19 |
+
|
20 |
+
def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
|
21 |
+
validation_embeddings: torch.FloatTensor,
|
22 |
+
neutral_embeddings: torch.FloatTensor,
|
23 |
+
sample_embeddings: torch.FloatTensor,
|
24 |
+
logger: SummaryWriter, use_amp: bool,
|
25 |
+
global_step: int,
|
26 |
+
validation_seed: int = 555,
|
27 |
+
):
|
28 |
+
print("validating...")
|
29 |
+
with isolate_rng(include_cuda=True), torch.no_grad():
|
30 |
+
set_seed(validation_seed)
|
31 |
+
criteria = torch.nn.MSELoss()
|
32 |
+
negative_guidance = 1
|
33 |
+
val_count = 5
|
34 |
+
|
35 |
+
nsteps=50
|
36 |
+
num_validation_prompts = validation_embeddings.shape[0] // 2
|
37 |
+
for i in range(0, num_validation_prompts):
|
38 |
+
accumulated_loss = None
|
39 |
+
this_validation_embeddings = validation_embeddings[i*2:i*2+2]
|
40 |
+
for j in range(val_count):
|
41 |
+
iteration = random.randint(1, nsteps)
|
42 |
+
diffused_latents = get_diffused_latents(diffuser, nsteps, this_validation_embeddings, iteration, use_amp)
|
43 |
+
|
44 |
+
with autocast(enabled=use_amp):
|
45 |
+
positive_latents = diffuser.predict_noise(iteration, diffused_latents, this_validation_embeddings, guidance_scale=1)
|
46 |
+
neutral_latents = diffuser.predict_noise(iteration, diffused_latents, neutral_embeddings, guidance_scale=1)
|
47 |
|
48 |
+
with finetuner, autocast(enabled=use_amp):
|
49 |
+
negative_latents = diffuser.predict_noise(iteration, diffused_latents, this_validation_embeddings, guidance_scale=1)
|
50 |
+
|
51 |
+
loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
|
52 |
+
accumulated_loss = (accumulated_loss or 0) + loss.item()
|
53 |
+
logger.add_scalar(f"loss/val_{i}", accumulated_loss/val_count, global_step=global_step)
|
54 |
+
|
55 |
+
num_samples = sample_embeddings.shape[0] // 2
|
56 |
+
for i in range(0, num_samples):
|
57 |
+
print(f'making sample {i}...')
|
58 |
+
with finetuner:
|
59 |
+
pipeline = StableDiffusionPipeline(vae=diffuser.vae,
|
60 |
+
text_encoder=diffuser.text_encoder,
|
61 |
+
tokenizer=diffuser.tokenizer,
|
62 |
+
unet=diffuser.unet,
|
63 |
+
scheduler=diffuser.scheduler,
|
64 |
+
safety_checker=None,
|
65 |
+
feature_extractor=None,
|
66 |
+
requires_safety_checker=False)
|
67 |
+
images = pipeline(prompt_embeds=sample_embeddings[i*2+1:i*2+2], negative_prompt_embeds=sample_embeddings[i*2:i*2+1],
|
68 |
+
num_inference_steps=50)
|
69 |
+
image_tensor = transforms.ToTensor()(images.images[0])
|
70 |
+
logger.add_image(f"samples/{i}", img_tensor=image_tensor, global_step=global_step)
|
71 |
+
|
72 |
+
"""
|
73 |
+
with finetuner, torch.cuda.amp.autocast(enabled=use_amp):
|
74 |
+
images = diffuser(
|
75 |
+
combined_embeddings=sample_embeddings[i*2:i*2+2],
|
76 |
+
n_steps=50
|
77 |
+
)
|
78 |
+
logger.add_images(f"samples/{i}", images)
|
79 |
+
"""
|
80 |
+
|
81 |
+
torch.cuda.empty_cache()
|
82 |
|
83 |
def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
|
84 |
+
use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1,
|
85 |
+
save_every_n_steps=-1, validate_every_n_steps=-1,
|
86 |
+
validation_prompts=[], sample_positive_prompts=[], sample_negative_prompts=[]):
|
87 |
+
|
88 |
+
diffuser = None
|
89 |
+
loss = None
|
90 |
+
optimizer = None
|
91 |
+
finetuner = None
|
92 |
+
negative_latents = None
|
93 |
+
neutral_latents = None
|
94 |
+
positive_latents = None
|
95 |
|
96 |
nsteps = 50
|
97 |
+
print(f"using img_size of {img_size}")
|
98 |
+
diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path, native_img_size=img_size).to('cuda')
|
99 |
+
logger = SummaryWriter(log_dir=f"logs/{os.path.splitext(os.path.basename(save_path))[0]}")
|
100 |
|
101 |
memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
|
102 |
use_gradient_checkpointing=use_gradient_checkpointing )
|
|
|
120 |
pbar = tqdm(range(iterations))
|
121 |
|
122 |
with torch.no_grad():
|
123 |
+
neutral_text_embeddings = diffuser.get_cond_and_uncond_embeddings([''], n_imgs=1)
|
124 |
+
positive_text_embeddings = diffuser.get_cond_and_uncond_embeddings([prompt], n_imgs=1)
|
125 |
+
validation_embeddings = diffuser.get_cond_and_uncond_embeddings(validation_prompts, n_imgs=1)
|
126 |
+
sample_embeddings = diffuser.get_cond_and_uncond_embeddings(sample_positive_prompts, sample_negative_prompts, n_imgs=1)
|
127 |
|
128 |
+
#if use_amp:
|
129 |
+
# diffuser.vae = diffuser.vae.to(diffuser.vae.device, dtype=torch.float16)
|
|
|
130 |
|
131 |
+
#del diffuser.text_encoder
|
132 |
+
#del diffuser.tokenizer
|
133 |
|
134 |
+
torch.cuda.empty_cache()
|
135 |
|
136 |
if seed == -1:
|
137 |
seed = random.randint(0, 2 ** 30)
|
|
|
140 |
prev_losses = []
|
141 |
start_loss = None
|
142 |
max_prev_loss_count = 10
|
143 |
+
try:
|
144 |
+
for i in pbar:
|
145 |
+
if training_should_cancel:
|
146 |
+
print("received cancellation request")
|
147 |
+
return None
|
148 |
|
149 |
+
with torch.no_grad():
|
150 |
+
optimizer.zero_grad()
|
151 |
|
152 |
+
iteration = torch.randint(1, nsteps - 1, (1,)).item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
+
with finetuner:
|
155 |
+
diffused_latents = get_diffused_latents(diffuser, nsteps, positive_text_embeddings, iteration, use_amp)
|
|
|
156 |
|
157 |
+
iteration = int(iteration / nsteps * 1000)
|
|
|
|
|
158 |
|
159 |
+
with autocast(enabled=use_amp):
|
160 |
+
positive_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
|
161 |
+
neutral_latents = diffuser.predict_noise(iteration, diffused_latents, neutral_text_embeddings, guidance_scale=1)
|
162 |
|
163 |
+
with finetuner:
|
164 |
+
with autocast(enabled=use_amp):
|
165 |
+
negative_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
|
|
|
166 |
|
167 |
+
positive_latents.requires_grad = False
|
168 |
+
neutral_latents.requires_grad = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
+
# loss = criteria(e_n, e_0) works the best try 5000 epochs
|
171 |
+
loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
|
172 |
+
memory_efficiency_wrapper.step(optimizer, loss)
|
173 |
+
optimizer.zero_grad()
|
174 |
|
175 |
+
logger.add_scalar("loss", loss.item(), global_step=i)
|
176 |
+
|
177 |
+
# print moving average loss
|
178 |
+
prev_losses.append(loss.detach().clone())
|
179 |
+
if len(prev_losses) > max_prev_loss_count:
|
180 |
+
prev_losses.pop(0)
|
181 |
+
if start_loss is None:
|
182 |
+
start_loss = prev_losses[-1]
|
183 |
+
if len(prev_losses) >= max_prev_loss_count:
|
184 |
+
moving_average_loss = sum(prev_losses) / len(prev_losses)
|
185 |
+
print(
|
186 |
+
f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()}")
|
187 |
+
else:
|
188 |
+
print(f"step {i}: loss={loss.item()}")
|
189 |
+
|
190 |
+
if save_every_n_steps > 0 and ((i+1) % save_every_n_steps) == 0:
|
191 |
+
torch.save(finetuner.state_dict(), save_path + f"__step_{i+1}.pt")
|
192 |
+
if validate_every_n_steps > 0 and ((i+1) % validate_every_n_steps) == 0:
|
193 |
+
validate(diffuser, finetuner,
|
194 |
+
validation_embeddings=validation_embeddings,
|
195 |
+
sample_embeddings=sample_embeddings,
|
196 |
+
neutral_embeddings=neutral_text_embeddings,
|
197 |
+
logger=logger, use_amp=False, global_step=i)
|
198 |
+
torch.save(finetuner.state_dict(), save_path)
|
199 |
+
return save_path
|
200 |
+
finally:
|
201 |
+
del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents
|
202 |
+
torch.cuda.empty_cache()
|
203 |
+
|
204 |
+
|
205 |
+
def get_diffused_latents(diffuser, nsteps, text_embeddings, end_iteration, use_amp):
|
206 |
+
diffuser.set_scheduler_timesteps(nsteps)
|
207 |
+
latents = diffuser.get_initial_latents(1, n_prompts=1)
|
208 |
+
latents_steps, _ = diffuser.diffusion(
|
209 |
+
latents,
|
210 |
+
text_embeddings,
|
211 |
+
start_iteration=0,
|
212 |
+
end_iteration=end_iteration,
|
213 |
+
guidance_scale=3,
|
214 |
+
show_progress=False,
|
215 |
+
use_amp=use_amp
|
216 |
+
)
|
217 |
+
# because return_latents is not passed to diffuser.diffusion(), latents_steps should have only 1 entry
|
218 |
+
# but we take the "last" (-1) entry because paranoia
|
219 |
+
diffused_latents = latents_steps[-1]
|
220 |
+
diffuser.set_scheduler_timesteps(1000)
|
221 |
+
del latents_steps, latents
|
222 |
+
return diffused_latents
|
223 |
|
|
|
224 |
|
|
|
225 |
if __name__ == '__main__':
|
226 |
|
227 |
import argparse
|