Commit f052712 (1 parent: ca20311), committed by root

add altdiffusion

Files changed (2):
  1. app.py +12 -4
  2. train_dreambooth.py +38 -14
app.py CHANGED
@@ -13,7 +13,7 @@ import zipfile
 import tarfile
 import urllib.parse
 import gc
-from diffusers import StableDiffusionPipeline
+# from diffusers import StableDiffusionPipeline
 from huggingface_hub import snapshot_download
 
 
@@ -34,6 +34,8 @@ if(is_gpu_associated):
     model_v1 = snapshot_download(repo_id="multimodalart/sd-fine-tunable")
     model_v2 = snapshot_download(repo_id="stabilityai/stable-diffusion-2")
     model_v2_512 = snapshot_download(repo_id="stabilityai/stable-diffusion-2-base")
+    model_alt = snapshot_download(repo_id="BAAI/AltDiffusion")
+    model_alt_m9 = snapshot_download(repo_id="BAAI/AltDiffusion-m9")
     safety_checker = snapshot_download(repo_id="multimodalart/sd-sc")
     model_to_load = model_v1
 
@@ -69,6 +71,10 @@ def swap_base_model(selected_model):
         model_to_load = model_v1
     elif(selected_model == "v2-768"):
         model_to_load = model_v2
+    elif(selected_model == "alt"):
+        model_to_load = model_alt
+    elif(selected_model == "alt_m9"):
+        model_to_load = model_alt_m9
     else:
         model_to_load = model_v2_512
 
@@ -288,11 +294,13 @@ def train(*inputs):
     pipe_is_set = False
     def generate(prompt, steps):
         torch.cuda.empty_cache()
-        from diffusers import StableDiffusionPipeline
+        # from diffusers import StableDiffusionPipeline
+        from diffusers import DiffusionPipeline
        global pipe_is_set
        if(not pipe_is_set):
            global pipe
-            pipe = StableDiffusionPipeline.from_pretrained("./output_model", torch_dtype=torch.float16)
+            # pipe = StableDiffusionPipeline.from_pretrained("./output_model", torch_dtype=torch.float16)
+            pipe = DiffusionPipeline.from_pretrained("./output_model", torch_dtype=torch.float16)
            pipe = pipe.to("cuda")
            pipe_is_set = True
 
@@ -477,7 +485,7 @@ with gr.Blocks(css=css) as demo:
 
     with gr.Row() as what_are_you_training:
         type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
-        base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-5", "v2-512", "v2-768"], value="v1-5", interactive=True)
+        base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-5", "v2-512", "v2-768", "alt", "alt_m9"], value="alt_m9", interactive=True)
 
     #Very hacky approach to emulate dynamically created Gradio components
     with gr.Row() as upload_your_concept:
 
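Note on the app.py change: the commit swaps StableDiffusionPipeline for the generic DiffusionPipeline. DiffusionPipeline.from_pretrained() reads the model_index.json shipped with a checkpoint and instantiates whatever pipeline class is recorded there, so the same generate() code path now serves both the Stable Diffusion bases and the two BAAI AltDiffusion bases. A minimal sketch of that dispatch outside the Space (prompt and repo id are illustrative; the class name assumes the checkpoint's model_index.json declares AltDiffusionPipeline):

    import torch
    from diffusers import DiffusionPipeline

    # from_pretrained() inspects model_index.json and returns the concrete
    # pipeline class declared there; no class is hard-coded here.
    pipe = DiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    print(type(pipe).__name__)  # expected: AltDiffusionPipeline

    # AltDiffusion-m9 is multilingual, so non-English prompts are valid input.
    image = pipe("一只戴帽子的猫的照片").images[0]
    image.save("cat.png")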
train_dreambooth.py CHANGED
@@ -17,17 +17,34 @@ from torch.utils.data import Dataset
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from huggingface_hub import HfFolder, Repository, whoami
 from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import AutoTokenizer, PretrainedConfig
 
 
 logger = get_logger(__name__)
 
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
+    text_encoder_config = PretrainedConfig.from_pretrained(
+        pretrained_model_name_or_path,
+        subfolder="text_encoder",
+    )
+    model_class = text_encoder_config.architectures[0]
+
+    if model_class == "CLIPTextModel":
+        from transformers import CLIPTextModel
+
+        return CLIPTextModel
+    elif model_class == "RobertaSeriesModelWithTransformation":
+        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+        return RobertaSeriesModelWithTransformation
+    else:
+        raise ValueError(f"{model_class} is not supported.")
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -471,7 +488,7 @@ def run_training(args_imported):
 
         if cur_class_images < args.num_class_images:
             torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
-            pipeline = StableDiffusionPipeline.from_pretrained(
+            pipeline = DiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path, torch_dtype=torch_dtype
             )
             pipeline.set_progress_bar_config(disable=True)
@@ -517,20 +534,27 @@ def run_training(args_imported):
 
     # Load the tokenizer
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=False)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+        tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", use_fast=False)
+
+    # support for Altdiffusion
+    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
 
     # Load models and create wrapper for stable diffusion
     if args.train_only_unet:
         if os.path.exists(str(args.output_dir+"/text_encoder_trained")):
-            text_encoder = CLIPTextModel.from_pretrained(args.output_dir, subfolder="text_encoder_trained")
+            # text_encoder = CLIPTextModel.from_pretrained(args.output_dir, subfolder="text_encoder_trained")
+            text_encoder = text_encoder_cls.from_pretrained(args.output_dir, subfolder="text_encoder_trained")
         elif os.path.exists(str(args.output_dir+"/text_encoder")):
-            text_encoder = CLIPTextModel.from_pretrained(args.output_dir, subfolder="text_encoder")
+            # text_encoder = CLIPTextModel.from_pretrained(args.output_dir, subfolder="text_encoder")
+            text_encoder = text_encoder_cls.from_pretrained(args.output_dir, subfolder="text_encoder")
         else:
-            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+            # text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+            text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
     else:
-        text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+        # text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+        text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
 
@@ -796,7 +820,7 @@ def run_training(args_imported):
         if os.path.exists(frz_dir):
             subprocess.call('rm -r '+ frz_dir, shell=True)
         os.mkdir(frz_dir)
-        pipeline = StableDiffusionPipeline.from_pretrained(
+        pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=accelerator.unwrap_model(unet),
             text_encoder=accelerator.unwrap_model(text_encoder),
@@ -816,7 +840,7 @@ def run_training(args_imported):
     print(" SAVING CHECKPOINT: "+args.Session_dir+"/"+inst+".ckpt")
     # Create the pipeline using the trained modules and save it.
     if accelerator.is_main_process:
-        pipeline = StableDiffusionPipeline.from_pretrained(
+        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
@@ -839,7 +863,7 @@ def run_training(args_imported):
            txt_dir=args.output_dir + "/text_encoder_trained"
            if not os.path.exists(txt_dir):
                os.mkdir(txt_dir)
-            pipeline = StableDiffusionPipeline.from_pretrained(
+            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet),
                text_encoder=accelerator.unwrap_model(text_encoder),
@@ -847,7 +871,7 @@ def run_training(args_imported):
            pipeline.text_encoder.save_pretrained(txt_dir)
 
        elif args.train_only_unet:
-            pipeline = StableDiffusionPipeline.from_pretrained(
+            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet),
                text_encoder=accelerator.unwrap_model(text_encoder),
@@ -857,7 +881,7 @@ def run_training(args_imported):
            subprocess.call('rm -r '+txt_dir, shell=True)
 
        else:
-            pipeline = StableDiffusionPipeline.from_pretrained(
+            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet),
                text_encoder=accelerator.unwrap_model(text_encoder),
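Note on import_model_class_from_model_name_or_path: instead of hard-coding CLIPTextModel, the script now reads text_encoder/config.json from the base model and picks the encoder class from its architectures field: CLIPTextModel for the Stable Diffusion bases, RobertaSeriesModelWithTransformation (the XLM-Roberta wrapper AltDiffusion uses) for the BAAI bases. A quick way to see what the helper would resolve to (repo ids are illustrative stand-ins that need Hub access; the expected output assumes the usual configs for these checkpoints):

    from transformers import PretrainedConfig

    for repo in ("runwayml/stable-diffusion-v1-5", "BAAI/AltDiffusion-m9"):
        # Same lookup the helper performs: read the text encoder's config
        # and inspect its declared architecture.
        cfg = PretrainedConfig.from_pretrained(repo, subfolder="text_encoder")
        print(repo, "->", cfg.architectures[0])
    # expected:
    # runwayml/stable-diffusion-v1-5 -> CLIPTextModel
    # BAAI/AltDiffusion-m9 -> RobertaSeriesModelWithTransformation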
 
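Note on the tokenizer change: the CLIPTokenizer calls become AutoTokenizer so the same two lines load whichever tokenizer a base model ships, CLIPTokenizer for the Stable Diffusion checkpoints and an XLM-Roberta-based tokenizer for the multilingual AltDiffusion ones; use_fast=False pins the slow Python tokenizer classes. A sketch of the resolution (same caveats as above: illustrative repo ids, and the AltDiffusion class name is an assumption about that checkpoint's tokenizer_config.json):

    from transformers import AutoTokenizer

    # AutoTokenizer dispatches on the tokenizer_config.json found in each
    # checkpoint's tokenizer/ subfolder.
    sd_tok = AutoTokenizer.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="tokenizer", use_fast=False)
    alt_tok = AutoTokenizer.from_pretrained(
        "BAAI/AltDiffusion-m9", subfolder="tokenizer", use_fast=False)
    print(type(sd_tok).__name__)   # CLIPTokenizer
    print(type(alt_tok).__name__)  # expected: XLMRobertaTokenizer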