Damian Stewart commited on
Commit
50b9662
1 Parent(s): 6067469

add tensorboard, validation, sample output

Browse files
Files changed (5) hide show
  1. StableDiffuser.py +32 -16
  2. app.py +111 -38
  3. memory_efficiency.py +1 -1
  4. requirements.txt +1 -0
  5. train.py +162 -57
StableDiffuser.py CHANGED
@@ -36,11 +36,13 @@ class StableDiffuser(torch.nn.Module):
36
  def __init__(self,
37
  scheduler='LMS',
38
  keep_pipeline=False,
 
39
  repo_id_or_path="CompVis/stable-diffusion-v1-4"):
40
 
41
  super().__init__()
42
 
43
  self.pipeline = StableDiffusionPipeline.from_pretrained(repo_id_or_path)
 
44
 
45
  self.vae = self.pipeline.vae
46
  self.unet = self.pipeline.unet
@@ -60,8 +62,10 @@ class StableDiffuser(torch.nn.Module):
60
  if not keep_pipeline:
61
  del self.pipeline
62
 
63
- def get_noise(self, batch_size, width, height, generator=None):
64
  param = list(self.parameters())[0]
 
 
65
  return torch.randn(
66
  (batch_size, self.unet.config.in_channels, width // 8, height // 8),
67
  generator=generator).type(param.dtype).to(param.device)
@@ -95,16 +99,20 @@ class StableDiffuser(torch.nn.Module):
95
  def set_scheduler_timesteps(self, n_steps):
96
  self.scheduler.set_timesteps(n_steps, device=self.unet.device)
97
 
98
- def get_initial_latents(self, n_imgs, height, width, n_prompts, generator=None):
 
 
99
  noise = self.get_noise(n_imgs, height, width, generator=generator).repeat(n_prompts, 1, 1, 1)
100
  latents = noise * self.scheduler.init_noise_sigma
101
  return latents
102
 
103
- def get_text_embeddings(self, prompts, negative_prompts=None, n_imgs=1):
104
  text_tokens = self.text_tokenize(prompts)
105
  text_embeddings = self.text_encode(text_tokens)
106
  if negative_prompts is None:
107
- negative_prompts = [""] * len(prompts)
 
 
108
  unconditional_tokens = self.text_tokenize(negative_prompts)
109
  unconditional_embeddings = self.text_encode(unconditional_tokens)
110
  text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)
@@ -136,12 +144,12 @@ class StableDiffuser(torch.nn.Module):
136
  @torch.no_grad()
137
  def diffusion(self,
138
  latents,
139
- text_embeddings,
140
  end_iteration=1000,
141
  start_iteration=0,
142
  return_steps=False,
143
  pred_x0=False,
144
- trace_args=None,
145
  show_progress=True,
146
  use_amp=False,
147
  **kwargs):
@@ -159,7 +167,7 @@ class StableDiffuser(torch.nn.Module):
159
  noise_pred = self.predict_noise(
160
  iteration,
161
  latents,
162
- text_embeddings,
163
  **kwargs)
164
 
165
  # compute the previous noisy sample x_t -> x_t-1
@@ -182,30 +190,38 @@ class StableDiffuser(torch.nn.Module):
182
 
183
  @torch.no_grad()
184
  def __call__(self,
185
- prompts,
186
- negative_prompts,
187
- width=512,
188
- height=512,
 
189
  n_steps=50,
190
  n_imgs=1,
191
  end_iteration=None,
192
  generator=None,
 
193
  **kwargs
194
  ):
195
 
196
  assert 0 <= n_steps <= 1000
197
 
198
- if not isinstance(prompts, list):
199
- prompts = [prompts]
 
 
 
 
 
200
 
201
  self.set_scheduler_timesteps(n_steps)
202
- latents = self.get_initial_latents(n_imgs, height, width, len(prompts), generator=generator)
203
- text_embeddings = self.get_text_embeddings(prompts,negative_prompts,n_imgs=n_imgs)
204
  end_iteration = end_iteration or n_steps
205
  latents_steps, trace_steps = self.diffusion(
206
  latents,
207
- text_embeddings,
208
  end_iteration=end_iteration,
 
209
  **kwargs
210
  )
211
 
 
36
  def __init__(self,
37
  scheduler='LMS',
38
  keep_pipeline=False,
39
+ native_img_size=512,
40
  repo_id_or_path="CompVis/stable-diffusion-v1-4"):
41
 
42
  super().__init__()
43
 
44
  self.pipeline = StableDiffusionPipeline.from_pretrained(repo_id_or_path)
45
+ self.native_image_size = native_img_size
46
 
47
  self.vae = self.pipeline.vae
48
  self.unet = self.pipeline.unet
 
62
  if not keep_pipeline:
63
  del self.pipeline
64
 
65
+ def get_noise(self, batch_size, width=None, height=None, generator=None):
66
  param = list(self.parameters())[0]
67
+ width = width or self.native_image_size
68
+ height = height or self.native_image_size
69
  return torch.randn(
70
  (batch_size, self.unet.config.in_channels, width // 8, height // 8),
71
  generator=generator).type(param.dtype).to(param.device)
 
99
  def set_scheduler_timesteps(self, n_steps):
100
  self.scheduler.set_timesteps(n_steps, device=self.unet.device)
101
 
102
+ def get_initial_latents(self, n_imgs, height=None, width=None, n_prompts=1, generator=None):
103
+ height = height or self.native_image_size
104
+ width = width or self.native_image_size
105
  noise = self.get_noise(n_imgs, height, width, generator=generator).repeat(n_prompts, 1, 1, 1)
106
  latents = noise * self.scheduler.init_noise_sigma
107
  return latents
108
 
109
+ def get_cond_and_uncond_embeddings(self, prompts, negative_prompts=None, n_imgs=1):
110
  text_tokens = self.text_tokenize(prompts)
111
  text_embeddings = self.text_encode(text_tokens)
112
  if negative_prompts is None:
113
+ negative_prompts = []
114
+ while len(negative_prompts) < len(prompts):
115
+ negative_prompts.append("")
116
  unconditional_tokens = self.text_tokenize(negative_prompts)
117
  unconditional_embeddings = self.text_encode(unconditional_tokens)
118
  text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)
 
144
  @torch.no_grad()
145
  def diffusion(self,
146
  latents,
147
+ uncond_and_cond_embeddings,
148
  end_iteration=1000,
149
  start_iteration=0,
150
  return_steps=False,
151
  pred_x0=False,
152
+ trace_args=None,
153
  show_progress=True,
154
  use_amp=False,
155
  **kwargs):
 
167
  noise_pred = self.predict_noise(
168
  iteration,
169
  latents,
170
+ uncond_and_cond_embeddings,
171
  **kwargs)
172
 
173
  # compute the previous noisy sample x_t -> x_t-1
 
190
 
191
  @torch.no_grad()
192
  def __call__(self,
193
+ prompts=None,
194
+ negative_prompts=None,
195
+ combined_embeddings=None, # uncond first, then cond
196
+ width=None,
197
+ height=None,
198
  n_steps=50,
199
  n_imgs=1,
200
  end_iteration=None,
201
  generator=None,
202
+ use_amp=False,
203
  **kwargs
204
  ):
205
 
206
  assert 0 <= n_steps <= 1000
207
 
208
+ if combined_embeddings is None:
209
+ assert prompts is not None, "missing prompts or combined_embeddings"
210
+ combined_embeddings = diffuser.get_cond_and_uncond_embeddings(prompts, negative_prompts, n_imgs=n_imgs)
211
+
212
+ width = width or self.native_image_size
213
+ height = height or self.native_image_size
214
+ num_prompts = combined_embeddings.shape[0] // 2
215
 
216
  self.set_scheduler_timesteps(n_steps)
217
+ latents = self.get_initial_latents(n_imgs, height, width, num_prompts, generator=generator)
218
+
219
  end_iteration = end_iteration or n_steps
220
  latents_steps, trace_steps = self.diffusion(
221
  latents,
222
+ combined_embeddings,
223
  end_iteration=end_iteration,
224
+ use_amp=use_amp,
225
  **kwargs
226
  )
227
 
app.py CHANGED
@@ -7,12 +7,13 @@ from diffusers.utils import is_xformers_available
7
  from finetuning import FineTunedModel
8
  from StableDiffuser import StableDiffuser
9
  from memory_efficiency import MemoryEfficiencyWrapper
10
- from train import train
11
 
12
  import os
13
 
 
14
  def populate_model_map():
15
- model_map = {}
16
  for model_file in os.listdir('models'):
17
  path = 'models/' + model_file
18
  if any([existing_path == path for existing_path in model_map.values()]):
@@ -28,6 +29,7 @@ SHARED_UI_WARNING = f'''## Attention - Training using the ESD-u method does not
28
  <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
29
  '''
30
 
 
31
 
32
  class Demo:
33
 
@@ -70,24 +72,11 @@ class Demo:
70
  self.negative_prompt_input_infr = gr.Text(
71
  label="Negative prompt"
72
  )
73
-
74
- with gr.Row():
75
-
76
- self.model_dropdown = gr.Dropdown(
77
- label="ESD Model",
78
- choices= list(model_map.keys()),
79
- value='Van Gogh',
80
- interactive=True
81
- )
82
- self.model_reload_button = gr.Button(
83
- value="🔄",
84
- interactive=True
85
- )
86
-
87
- self.seed_infr = gr.Number(
88
- label="Seed",
89
- value=42
90
- )
91
  self.img_width_infr = gr.Slider(
92
  label="Image width",
93
  minimum=256,
@@ -95,7 +84,6 @@ class Demo:
95
  value=512,
96
  step=64
97
  )
98
-
99
  self.img_height_infr = gr.Slider(
100
  label="Image height",
101
  minimum=256,
@@ -104,6 +92,18 @@ class Demo:
104
  step=64
105
  )
106
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  self.base_repo_id_or_path_input_infr = gr.Text(
108
  label="Base model",
109
  value="CompVis/stable-diffusion-v1-4",
@@ -131,14 +131,12 @@ class Demo:
131
  with gr.Tab("Train") as training_column:
132
 
133
  with gr.Row():
134
-
135
  self.explain_train= gr.Markdown(interactive=False,
136
  value='In this part you can erase any concept from Stable Diffusion. Enter a prompt for the concept or style you want to erase, and select ESD-x if you want to focus erasure on prompts that mention the concept explicitly. [NOTE: ESD-u is currently unavailable in this space. But you can duplicate the space and run it on GPU with VRAM >40GB for enabling ESD-u]. With default settings, it takes about 15 minutes to fine-tune the model; then you can try inference above or download the weights. The training code used here is slightly different than the code tested in the original paper. Code and details are at [github link](https://github.com/rohitgandikota/erasing).')
137
 
138
  with gr.Row():
139
 
140
  with gr.Column(scale=3):
141
-
142
  self.train_model_input = gr.Text(
143
  label="Model to Edit",
144
  value="CompVis/stable-diffusion-v1-4",
@@ -196,7 +194,7 @@ class Demo:
196
  )
197
  self.train_save_every_input = gr.Number(
198
  value=-1,
199
- label="Save every N steps",
200
  info="If >0, save the model throughout training at the given step interval."
201
  )
202
 
@@ -210,6 +208,28 @@ class Demo:
210
  self.train_use_gradient_checkpointing_input = gr.Checkbox(
211
  label="Gradient checkpointing", value=False)
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  with gr.Column(scale=1):
214
 
215
  self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)
@@ -219,7 +239,7 @@ class Demo:
219
  )
220
 
221
  self.train_cancel_button = gr.Button(
222
- value="Cancel training"
223
  )
224
 
225
  self.download = gr.Files()
@@ -260,6 +280,7 @@ class Demo:
260
  value='', variant='primary', label='Status', interactive=False)
261
  self.export_button = gr.Button(
262
  value="Export")
 
263
 
264
  self.infr_button.click(self.inference, inputs = [
265
  self.prompt_input_infr,
@@ -292,10 +313,16 @@ class Demo:
292
  self.train_use_gradient_checkpointing_input,
293
  self.train_seed_input,
294
  self.train_save_every_input,
 
 
 
 
295
  ],
296
  outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
297
  )
298
- self.train_cancel_button.click(lambda x: print("cancel pressed"), cancels=[train_event])
 
 
299
 
300
  self.export_button.click(self.export, inputs = [
301
  self.model_dropdown_export,
@@ -303,23 +330,51 @@ class Demo:
303
  self.save_path_input_export,
304
  self.save_half_export
305
  ],
306
- outputs=[self.export_status]
307
  )
308
 
309
  def reload_models(self, model_dropdown):
310
  current_model_name = model_dropdown
311
  global model_map
312
- model_map = populate_model_map()
313
- return [gr.Dropdown.update(choices=list(model_map.keys()), value=current_model_name)]
 
 
 
 
314
 
315
  def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
316
  use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
317
  seed=-1, save_every=-1,
318
- pbar = gr.Progress(track_tqdm=True)):
319
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  if self.training:
321
  return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
322
 
 
 
323
  print(f"Training {repo_id_or_path} at {img_size} to remove '{prompt}'.")
324
  print(f" {train_method}, negative guidance {neg_guidance}, lr {lr}, {iterations} iterations.")
325
  print(f" {'✅' if use_gradient_checkpointing else '❌'} gradient checkpointing")
@@ -348,23 +403,38 @@ class Demo:
348
  break
349
  # repeat until a not-in-use path is found
350
 
 
 
 
 
 
 
 
351
  try:
352
  self.training = True
353
  self.train_cancel_button.update(interactive=True)
354
- train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
355
  use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing,
356
- seed=int(seed), save_every=int(save_every))
 
 
 
 
 
 
 
 
357
  finally:
358
  self.training = False
359
  self.train_cancel_button.update(interactive=False)
360
 
361
  torch.cuda.empty_cache()
362
 
363
- new_model_name = f'{os.path.basename(save_path)}'
364
- model_map[new_model_name] = save_path
365
 
366
  return [gr.update(interactive=True, value='Train'),
367
- gr.update(value=f'Done Training! Try your model ({new_model_name}) in the "Test" tab'),
368
  save_path,
369
  gr.Dropdown.update(choices=list(model_map.keys()), value=new_model_name)]
370
 
@@ -373,7 +443,7 @@ class Demo:
373
  checkpoint = torch.load(model_path)
374
  diffuser = StableDiffuser(scheduler='DDIM',
375
  keep_pipeline=True,
376
- repo_id_or_path=base_repo_id_or_path
377
  ).eval()
378
  finetuner = FineTunedModel.from_checkpoint(diffuser, checkpoint).eval()
379
  with finetuner:
@@ -381,7 +451,10 @@ class Demo:
381
  diffuser = diffuser.half()
382
  diffuser.pipeline.to('cpu', torch_dtype=torch.float16)
383
  diffuser.pipeline.save_pretrained(save_path)
384
- return [gr.update(value=f'Done! Your model is at {save_path}.')]
 
 
 
385
 
386
 
387
  def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
 
7
  from finetuning import FineTunedModel
8
  from StableDiffuser import StableDiffuser
9
  from memory_efficiency import MemoryEfficiencyWrapper
10
+ from train import train, training_should_cancel
11
 
12
  import os
13
 
14
+ model_map = {}
15
  def populate_model_map():
16
+ global model_map
17
  for model_file in os.listdir('models'):
18
  path = 'models/' + model_file
19
  if any([existing_path == path for existing_path in model_map.values()]):
 
29
  <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
30
  '''
31
 
32
+ # work around Gradio's weird threading
33
 
34
  class Demo:
35
 
 
72
  self.negative_prompt_input_infr = gr.Text(
73
  label="Negative prompt"
74
  )
75
+ self.seed_infr = gr.Number(
76
+ label="Seed",
77
+ value=42
78
+ )
79
+ with gr.Row(scale=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  self.img_width_infr = gr.Slider(
81
  label="Image width",
82
  minimum=256,
 
84
  value=512,
85
  step=64
86
  )
 
87
  self.img_height_infr = gr.Slider(
88
  label="Image height",
89
  minimum=256,
 
92
  step=64
93
  )
94
 
95
+ with gr.Row(scale=1):
96
+ self.model_dropdown = gr.Dropdown(
97
+ label="ESD Model",
98
+ choices= list(model_map.keys()),
99
+ value='Van Gogh',
100
+ interactive=True
101
+ )
102
+ self.model_reload_button = gr.Button(
103
+ value="🔄",
104
+ interactive=True
105
+ )
106
+
107
  self.base_repo_id_or_path_input_infr = gr.Text(
108
  label="Base model",
109
  value="CompVis/stable-diffusion-v1-4",
 
131
  with gr.Tab("Train") as training_column:
132
 
133
  with gr.Row():
 
134
  self.explain_train= gr.Markdown(interactive=False,
135
  value='In this part you can erase any concept from Stable Diffusion. Enter a prompt for the concept or style you want to erase, and select ESD-x if you want to focus erasure on prompts that mention the concept explicitly. [NOTE: ESD-u is currently unavailable in this space. But you can duplicate the space and run it on GPU with VRAM >40GB for enabling ESD-u]. With default settings, it takes about 15 minutes to fine-tune the model; then you can try inference above or download the weights. The training code used here is slightly different than the code tested in the original paper. Code and details are at [github link](https://github.com/rohitgandikota/erasing).')
136
 
137
  with gr.Row():
138
 
139
  with gr.Column(scale=3):
 
140
  self.train_model_input = gr.Text(
141
  label="Model to Edit",
142
  value="CompVis/stable-diffusion-v1-4",
 
194
  )
195
  self.train_save_every_input = gr.Number(
196
  value=-1,
197
+ label="Save Every N Steps",
198
  info="If >0, save the model throughout training at the given step interval."
199
  )
200
 
 
208
  self.train_use_gradient_checkpointing_input = gr.Checkbox(
209
  label="Gradient checkpointing", value=False)
210
 
211
+ self.train_validation_prompts = gr.TextArea(
212
+ label="Validation Prompts",
213
+ placeholder="Probably, you want to put the \"Prompt to Erase\" in here as the first entry...",
214
+ value='',
215
+ info="Prompts for producing validation graphs, one per line."
216
+ )
217
+ self.train_sample_positive_prompts = gr.TextArea(
218
+ label="Sample Prompts",
219
+ value='',
220
+ info="Positive prompts for generating sample images, one per line."
221
+ )
222
+ self.train_sample_negative_prompts = gr.TextArea(
223
+ label="Sample Negative Prompts",
224
+ value='',
225
+ info="Negative prompts for use when generating sample images. One for each positive prompt, or leave empty for none."
226
+ )
227
+ self.train_validate_every_n_steps = gr.Number(
228
+ label="Validate Every N Steps",
229
+ value=20,
230
+ info="Validation and sample generation will be run at intervals of this many steps"
231
+ )
232
+
233
  with gr.Column(scale=1):
234
 
235
  self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)
 
239
  )
240
 
241
  self.train_cancel_button = gr.Button(
242
+ value="Cancel Training"
243
  )
244
 
245
  self.download = gr.Files()
 
280
  value='', variant='primary', label='Status', interactive=False)
281
  self.export_button = gr.Button(
282
  value="Export")
283
+ self.export_download = gr.Files()
284
 
285
  self.infr_button.click(self.inference, inputs = [
286
  self.prompt_input_infr,
 
313
  self.train_use_gradient_checkpointing_input,
314
  self.train_seed_input,
315
  self.train_save_every_input,
316
+ self.train_validation_prompts,
317
+ self.train_sample_positive_prompts,
318
+ self.train_sample_negative_prompts,
319
+ self.train_validate_every_n_steps
320
  ],
321
  outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
322
  )
323
+ self.train_cancel_button.click(self.cancel_training,
324
+ inputs=[],
325
+ outputs=[self.train_cancel_button])
326
 
327
  self.export_button.click(self.export, inputs = [
328
  self.model_dropdown_export,
 
330
  self.save_path_input_export,
331
  self.save_half_export
332
  ],
333
+ outputs=[self.export_button, self.export_status, self.export_download]
334
  )
335
 
336
  def reload_models(self, model_dropdown):
337
  current_model_name = model_dropdown
338
  global model_map
339
+ populate_model_map()
340
+ return [self.model_dropdown.update(choices=list(model_map.keys()), value=current_model_name)]
341
+
342
+ def cancel_training(self):
343
+ train.training_should_cancel = True
344
+ return [gr.update(value="Cancelling...", interactive=False)]
345
 
346
  def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
347
  use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
348
  seed=-1, save_every=-1,
349
+ validation_prompts: str=None, sample_positive_prompts: str=None, sample_negative_prompts: str=None, validate_every_n_steps=-1,
350
+ pbar=gr.Progress(track_tqdm=True)):
351
+ """
352
+
353
+ :param repo_id_or_path:
354
+ :param img_size:
355
+ :param prompt:
356
+ :param train_method:
357
+ :param neg_guidance:
358
+ :param iterations:
359
+ :param lr:
360
+ :param use_adamw8bit:
361
+ :param use_xformers:
362
+ :param use_amp:
363
+ :param use_gradient_checkpointing:
364
+ :param seed:
365
+ :param save_every:
366
+ :param validation_prompts: split on \n
367
+ :param sample_positive_prompts: split on \n
368
+ :param sample_negative_prompts: split on \n
369
+ :param validate_every_n_steps: split on \n
370
+ :param pbar:
371
+ :return:
372
+ """
373
  if self.training:
374
  return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
375
 
376
+ train.training_should_cancel = False
377
+
378
  print(f"Training {repo_id_or_path} at {img_size} to remove '{prompt}'.")
379
  print(f" {train_method}, negative guidance {neg_guidance}, lr {lr}, {iterations} iterations.")
380
  print(f" {'✅' if use_gradient_checkpointing else '❌'} gradient checkpointing")
 
403
  break
404
  # repeat until a not-in-use path is found
405
 
406
+ validation_prompts = [] if validation_prompts is None else validation_prompts.split('\n')
407
+ sample_positive_prompts = [] if sample_positive_prompts is None else sample_positive_prompts.split('\n')
408
+ sample_negative_prompts = [] if sample_negative_prompts is None else sample_negative_prompts.split('\n')
409
+ print(f"validation prompts: {validation_prompts}")
410
+ print(f"sample positive prompts: {sample_positive_prompts}")
411
+ print(f"sample negative prompts: {sample_negative_prompts}")
412
+
413
  try:
414
  self.training = True
415
  self.train_cancel_button.update(interactive=True)
416
+ save_path = train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
417
  use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing,
418
+ seed=int(seed), save_every_n_steps=int(save_every),
419
+ validate_every_n_steps=validate_every_n_steps, validation_prompts=validation_prompts,
420
+ sample_positive_prompts=sample_positive_prompts, sample_negative_prompts=sample_negative_prompts)
421
+ if save_path is None:
422
+ new_model_name = None
423
+ finished_message = "Training cancelled."
424
+ else:
425
+ new_model_name = f'{os.path.basename(save_path)}'
426
+ finished_message = f'Done Training! Try your model ({new_model_name}) in the "Test" tab'
427
  finally:
428
  self.training = False
429
  self.train_cancel_button.update(interactive=False)
430
 
431
  torch.cuda.empty_cache()
432
 
433
+ if new_model_name is not None:
434
+ model_map[new_model_name] = save_path
435
 
436
  return [gr.update(interactive=True, value='Train'),
437
+ gr.update(value=finished_message),
438
  save_path,
439
  gr.Dropdown.update(choices=list(model_map.keys()), value=new_model_name)]
440
 
 
443
  checkpoint = torch.load(model_path)
444
  diffuser = StableDiffuser(scheduler='DDIM',
445
  keep_pipeline=True,
446
+ repo_id_or_path=base_repo_id_or_path,
447
  ).eval()
448
  finetuner = FineTunedModel.from_checkpoint(diffuser, checkpoint).eval()
449
  with finetuner:
 
451
  diffuser = diffuser.half()
452
  diffuser.pipeline.to('cpu', torch_dtype=torch.float16)
453
  diffuser.pipeline.save_pretrained(save_path)
454
+
455
+ return [gr.update(interactive=True, value='Export'),
456
+ gr.update(value=f'Done Exporting!'),
457
+ save_path]
458
 
459
 
460
  def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
memory_efficiency.py CHANGED
@@ -44,7 +44,7 @@ class MemoryEfficiencyWrapper:
44
  print("xformers disabled via arg, using attention slicing instead")
45
  self.diffuser.unet.set_attention_slice("auto")
46
 
47
- self.diffuser.vae = self.diffuser.vae.to(self.diffuser.vae.device, dtype=torch.float16 if self.use_amp else torch.float32)
48
  self.diffuser.unet = self.diffuser.unet.to(self.diffuser.unet.device, dtype=torch.float32)
49
 
50
  try:
 
44
  print("xformers disabled via arg, using attention slicing instead")
45
  self.diffuser.unet.set_attention_slice("auto")
46
 
47
+ #self.diffuser.vae = self.diffuser.vae.to(self.diffuser.vae.device, dtype=torch.float16 if self.use_amp else torch.float32)
48
  self.diffuser.unet = self.diffuser.unet.to(self.diffuser.unet.device, dtype=torch.float32)
49
 
50
  try:
requirements.txt CHANGED
@@ -9,3 +9,4 @@ git+https://github.com/davidbau/baukit.git
9
  xformers
10
  bitsandbytes==0.38.1
11
  safetensors
 
 
9
  xformers
10
  bitsandbytes==0.38.1
11
  safetensors
12
+ tensorboard
train.py CHANGED
@@ -1,7 +1,10 @@
 
1
  import random
2
 
3
  from accelerate.utils import set_seed
 
4
  from torch.cuda.amp import autocast
 
5
 
6
  from StableDiffuser import StableDiffuser
7
  from finetuning import FineTunedModel
@@ -10,13 +13,90 @@ from tqdm import tqdm
10
 
11
  from isolate_rng import isolate_rng
12
  from memory_efficiency import MemoryEfficiencyWrapper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
16
- use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1, save_every=-1):
 
 
 
 
 
 
 
 
 
 
17
 
18
  nsteps = 50
19
- diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
 
 
20
 
21
  memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
22
  use_gradient_checkpointing=use_gradient_checkpointing )
@@ -40,16 +120,18 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
40
  pbar = tqdm(range(iterations))
41
 
42
  with torch.no_grad():
43
- neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
44
- positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
 
 
45
 
46
- del diffuser.vae
47
- del diffuser.text_encoder
48
- del diffuser.tokenizer
49
 
50
- torch.cuda.empty_cache()
 
51
 
52
- print(f"using img_size of {img_size}")
53
 
54
  if seed == -1:
55
  seed = random.randint(0, 2 ** 30)
@@ -58,65 +140,88 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
58
  prev_losses = []
59
  start_loss = None
60
  max_prev_loss_count = 10
61
- for i in pbar:
62
- with torch.no_grad():
63
- diffuser.set_scheduler_timesteps(nsteps)
64
- optimizer.zero_grad()
 
65
 
66
- iteration = torch.randint(1, nsteps - 1, (1,)).item()
67
- latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)
68
 
69
- with finetuner:
70
- latents_steps, _ = diffuser.diffusion(
71
- latents,
72
- positive_text_embeddings,
73
- start_iteration=0,
74
- end_iteration=iteration,
75
- guidance_scale=3,
76
- show_progress=False,
77
- use_amp=use_amp
78
- )
79
-
80
- diffuser.set_scheduler_timesteps(1000)
81
- iteration = int(iteration / nsteps * 1000)
82
 
83
- with autocast(enabled=use_amp):
84
- positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
85
- neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
86
 
87
- with finetuner:
88
- with autocast(enabled=use_amp):
89
- negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
90
 
91
- positive_latents.requires_grad = False
92
- neutral_latents.requires_grad = False
 
93
 
94
- # loss = criteria(e_n, e_0) works the best try 5000 epochs
95
- loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
96
- memory_efficiency_wrapper.step(optimizer, loss)
97
- optimizer.zero_grad()
98
 
99
- # print moving average loss
100
- prev_losses.append(loss.detach().clone())
101
- if len(prev_losses) > max_prev_loss_count:
102
- prev_losses.pop(0)
103
- if start_loss is None:
104
- start_loss = prev_losses[-1]
105
- if len(prev_losses) >= max_prev_loss_count:
106
- moving_average_loss = sum(prev_losses) / len(prev_losses)
107
- print(
108
- f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()}")
109
- else:
110
- print(f"step {i}: loss={loss.item()}")
111
 
112
- if save_every > 0 and ((i % save_every) == (save_every-1)):
113
- torch.save(finetuner.state_dict(), save_path + f"__step_{i}.pt")
 
 
114
 
115
- torch.save(finetuner.state_dict(), save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
- del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents
118
 
119
- torch.cuda.empty_cache()
120
  if __name__ == '__main__':
121
 
122
  import argparse
 
1
+ import os.path
2
  import random
3
 
4
  from accelerate.utils import set_seed
5
+ from diffusers import StableDiffusionPipeline
6
  from torch.cuda.amp import autocast
7
+ from torchvision import transforms
8
 
9
  from StableDiffuser import StableDiffuser
10
  from finetuning import FineTunedModel
 
13
 
14
  from isolate_rng import isolate_rng
15
  from memory_efficiency import MemoryEfficiencyWrapper
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ training_should_cancel = False
19
+
20
+ def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
21
+ validation_embeddings: torch.FloatTensor,
22
+ neutral_embeddings: torch.FloatTensor,
23
+ sample_embeddings: torch.FloatTensor,
24
+ logger: SummaryWriter, use_amp: bool,
25
+ global_step: int,
26
+ validation_seed: int = 555,
27
+ ):
28
+ print("validating...")
29
+ with isolate_rng(include_cuda=True), torch.no_grad():
30
+ set_seed(validation_seed)
31
+ criteria = torch.nn.MSELoss()
32
+ negative_guidance = 1
33
+ val_count = 5
34
+
35
+ nsteps=50
36
+ num_validation_prompts = validation_embeddings.shape[0] // 2
37
+ for i in range(0, num_validation_prompts):
38
+ accumulated_loss = None
39
+ this_validation_embeddings = validation_embeddings[i*2:i*2+2]
40
+ for j in range(val_count):
41
+ iteration = random.randint(1, nsteps)
42
+ diffused_latents = get_diffused_latents(diffuser, nsteps, this_validation_embeddings, iteration, use_amp)
43
+
44
+ with autocast(enabled=use_amp):
45
+ positive_latents = diffuser.predict_noise(iteration, diffused_latents, this_validation_embeddings, guidance_scale=1)
46
+ neutral_latents = diffuser.predict_noise(iteration, diffused_latents, neutral_embeddings, guidance_scale=1)
47
 
48
+ with finetuner, autocast(enabled=use_amp):
49
+ negative_latents = diffuser.predict_noise(iteration, diffused_latents, this_validation_embeddings, guidance_scale=1)
50
+
51
+ loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
52
+ accumulated_loss = (accumulated_loss or 0) + loss.item()
53
+ logger.add_scalar(f"loss/val_{i}", accumulated_loss/val_count, global_step=global_step)
54
+
55
+ num_samples = sample_embeddings.shape[0] // 2
56
+ for i in range(0, num_samples):
57
+ print(f'making sample {i}...')
58
+ with finetuner:
59
+ pipeline = StableDiffusionPipeline(vae=diffuser.vae,
60
+ text_encoder=diffuser.text_encoder,
61
+ tokenizer=diffuser.tokenizer,
62
+ unet=diffuser.unet,
63
+ scheduler=diffuser.scheduler,
64
+ safety_checker=None,
65
+ feature_extractor=None,
66
+ requires_safety_checker=False)
67
+ images = pipeline(prompt_embeds=sample_embeddings[i*2+1:i*2+2], negative_prompt_embeds=sample_embeddings[i*2:i*2+1],
68
+ num_inference_steps=50)
69
+ image_tensor = transforms.ToTensor()(images.images[0])
70
+ logger.add_image(f"samples/{i}", img_tensor=image_tensor, global_step=global_step)
71
+
72
+ """
73
+ with finetuner, torch.cuda.amp.autocast(enabled=use_amp):
74
+ images = diffuser(
75
+ combined_embeddings=sample_embeddings[i*2:i*2+2],
76
+ n_steps=50
77
+ )
78
+ logger.add_images(f"samples/{i}", images)
79
+ """
80
+
81
+ torch.cuda.empty_cache()
82
 
83
  def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
84
+ use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1,
85
+ save_every_n_steps=-1, validate_every_n_steps=-1,
86
+ validation_prompts=[], sample_positive_prompts=[], sample_negative_prompts=[]):
87
+
88
+ diffuser = None
89
+ loss = None
90
+ optimizer = None
91
+ finetuner = None
92
+ negative_latents = None
93
+ neutral_latents = None
94
+ positive_latents = None
95
 
96
  nsteps = 50
97
+ print(f"using img_size of {img_size}")
98
+ diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path, native_img_size=img_size).to('cuda')
99
+ logger = SummaryWriter(log_dir=f"logs/{os.path.splitext(os.path.basename(save_path))[0]}")
100
 
101
  memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
102
  use_gradient_checkpointing=use_gradient_checkpointing )
 
120
  pbar = tqdm(range(iterations))
121
 
122
  with torch.no_grad():
123
+ neutral_text_embeddings = diffuser.get_cond_and_uncond_embeddings([''], n_imgs=1)
124
+ positive_text_embeddings = diffuser.get_cond_and_uncond_embeddings([prompt], n_imgs=1)
125
+ validation_embeddings = diffuser.get_cond_and_uncond_embeddings(validation_prompts, n_imgs=1)
126
+ sample_embeddings = diffuser.get_cond_and_uncond_embeddings(sample_positive_prompts, sample_negative_prompts, n_imgs=1)
127
 
128
+ #if use_amp:
129
+ # diffuser.vae = diffuser.vae.to(diffuser.vae.device, dtype=torch.float16)
 
130
 
131
+ #del diffuser.text_encoder
132
+ #del diffuser.tokenizer
133
 
134
+ torch.cuda.empty_cache()
135
 
136
  if seed == -1:
137
  seed = random.randint(0, 2 ** 30)
 
140
  prev_losses = []
141
  start_loss = None
142
  max_prev_loss_count = 10
143
+ try:
144
+ for i in pbar:
145
+ if training_should_cancel:
146
+ print("received cancellation request")
147
+ return None
148
 
149
+ with torch.no_grad():
150
+ optimizer.zero_grad()
151
 
152
+ iteration = torch.randint(1, nsteps - 1, (1,)).item()
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
+ with finetuner:
155
+ diffused_latents = get_diffused_latents(diffuser, nsteps, positive_text_embeddings, iteration, use_amp)
 
156
 
157
+ iteration = int(iteration / nsteps * 1000)
 
 
158
 
159
+ with autocast(enabled=use_amp):
160
+ positive_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
161
+ neutral_latents = diffuser.predict_noise(iteration, diffused_latents, neutral_text_embeddings, guidance_scale=1)
162
 
163
+ with finetuner:
164
+ with autocast(enabled=use_amp):
165
+ negative_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
 
166
 
167
+ positive_latents.requires_grad = False
168
+ neutral_latents.requires_grad = False
 
 
 
 
 
 
 
 
 
 
169
 
170
+ # loss = criteria(e_n, e_0) works the best try 5000 epochs
171
+ loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
172
+ memory_efficiency_wrapper.step(optimizer, loss)
173
+ optimizer.zero_grad()
174
 
175
+ logger.add_scalar("loss", loss.item(), global_step=i)
176
+
177
+ # print moving average loss
178
+ prev_losses.append(loss.detach().clone())
179
+ if len(prev_losses) > max_prev_loss_count:
180
+ prev_losses.pop(0)
181
+ if start_loss is None:
182
+ start_loss = prev_losses[-1]
183
+ if len(prev_losses) >= max_prev_loss_count:
184
+ moving_average_loss = sum(prev_losses) / len(prev_losses)
185
+ print(
186
+ f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()}")
187
+ else:
188
+ print(f"step {i}: loss={loss.item()}")
189
+
190
+ if save_every_n_steps > 0 and ((i+1) % save_every_n_steps) == 0:
191
+ torch.save(finetuner.state_dict(), save_path + f"__step_{i+1}.pt")
192
+ if validate_every_n_steps > 0 and ((i+1) % validate_every_n_steps) == 0:
193
+ validate(diffuser, finetuner,
194
+ validation_embeddings=validation_embeddings,
195
+ sample_embeddings=sample_embeddings,
196
+ neutral_embeddings=neutral_text_embeddings,
197
+ logger=logger, use_amp=False, global_step=i)
198
+ torch.save(finetuner.state_dict(), save_path)
199
+ return save_path
200
+ finally:
201
+ del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents
202
+ torch.cuda.empty_cache()
203
+
204
+
205
+ def get_diffused_latents(diffuser, nsteps, text_embeddings, end_iteration, use_amp):
206
+ diffuser.set_scheduler_timesteps(nsteps)
207
+ latents = diffuser.get_initial_latents(1, n_prompts=1)
208
+ latents_steps, _ = diffuser.diffusion(
209
+ latents,
210
+ text_embeddings,
211
+ start_iteration=0,
212
+ end_iteration=end_iteration,
213
+ guidance_scale=3,
214
+ show_progress=False,
215
+ use_amp=use_amp
216
+ )
217
+ # because return_latents is not passed to diffuser.diffusion(), latents_steps should have only 1 entry
218
+ # but we take the "last" (-1) entry because paranoia
219
+ diffused_latents = latents_steps[-1]
220
+ diffuser.set_scheduler_timesteps(1000)
221
+ del latents_steps, latents
222
+ return diffused_latents
223
 
 
224
 
 
225
  if __name__ == '__main__':
226
 
227
  import argparse