multimodalart HF staff commited on
Commit
a47c17f
·
1 Parent(s): 74e4319

Add v2-768

Browse files
Files changed (2) hide show
  1. app.py +4 -5
  2. train_dreambooth.py +3 -3
app.py CHANGED
@@ -58,8 +58,8 @@ def swap_base_model(selected_model):
58
  global model_to_load
59
  if(selected_model == "v1-5"):
60
  model_to_load = model_v1
61
- #elif(selected_model == "v2-768"):
62
- # model_to_load = model_v2
63
  else:
64
  model_to_load = model_v2_512
65
 
@@ -171,8 +171,7 @@ def train(*inputs):
171
  Training_Steps=1400
172
 
173
  stptxt = int((Training_Steps*Train_text_encoder_for)/100)
174
- #gradient_checkpointing = False if which_model == "v1-5" else True
175
- gradient_checkpointing=False
176
  resolution = 512 if which_model != "v2-768" else 768
177
  cache_latents = True if which_model != "v1-5" else False
178
  if (type_of_thing == "object" or type_of_thing == "style" or (type_of_thing == "person" and not experimental_face_improvement)):
@@ -445,7 +444,7 @@ with gr.Blocks(css=css) as demo:
445
 
446
  with gr.Row() as what_are_you_training:
447
  type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
448
- base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-5", "v2-512"], value="v1-5", interactive=True)
449
 
450
  #Very hacky approach to emulate dynamically created Gradio components
451
  with gr.Row() as upload_your_concept:
 
58
  global model_to_load
59
  if(selected_model == "v1-5"):
60
  model_to_load = model_v1
61
+ elif(selected_model == "v2-768"):
62
+ model_to_load = model_v2
63
  else:
64
  model_to_load = model_v2_512
65
 
 
171
  Training_Steps=1400
172
 
173
  stptxt = int((Training_Steps*Train_text_encoder_for)/100)
174
+ gradient_checkpointing = False if which_model == "v1-5" else True
 
175
  resolution = 512 if which_model != "v2-768" else 768
176
  cache_latents = True if which_model != "v1-5" else False
177
  if (type_of_thing == "object" or type_of_thing == "style" or (type_of_thing == "person" and not experimental_face_improvement)):
 
444
 
445
  with gr.Row() as what_are_you_training:
446
  type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
447
+ base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-5", "v2-512", "v2-768"], value="v1-5", interactive=True)
448
 
449
  #Very hacky approach to emulate dynamically created Gradio components
450
  with gr.Row() as upload_your_concept:
train_dreambooth.py CHANGED
@@ -710,10 +710,10 @@ def run_training(args_imported):
710
  # Convert images to latent space
711
  with torch.no_grad():
712
  if args.cache_latents:
713
- latents = batch[0][0]
714
  else:
715
- latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
716
- latents = latents * 0.18215
717
 
718
  # Sample noise that we'll add to the latents
719
  noise = torch.randn_like(latents)
 
710
  # Convert images to latent space
711
  with torch.no_grad():
712
  if args.cache_latents:
713
+ latents_dist = batch[0][0]
714
  else:
715
+ latents_dist = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist
716
+ latents = latents_dist.sample() * 0.18215
717
 
718
  # Sample noise that we'll add to the latents
719
  noise = torch.randn_like(latents)