root committed
Commit • f052712
1 Parent(s): ca20311

add altdiffusion
Files changed:
- app.py +12 -4
- train_dreambooth.py +38 -14
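Background for the change (not part of the commit itself): AltDiffusion's text encoder is XLM-RoBERTa based rather than CLIP based, so the CLIP-specific Stable Diffusion classes cannot load it. Both files therefore switch to the generic diffusers/transformers entry points, which inspect the checkpoint and pick the matching classes. A minimal sketch of that generic loading path, assuming a diffusers release that ships AltDiffusion support (the prompt is illustrative only):

import torch
from diffusers import DiffusionPipeline

# DiffusionPipeline reads the checkpoint's model_index.json and dispatches to the
# matching pipeline class (StableDiffusionPipeline, AltDiffusionPipeline, ...).
pipe = DiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
image = pipe("a photo of an astronaut riding a horse").images[0]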
app.py
CHANGED
@@ -13,7 +13,7 @@ import zipfile
 import tarfile
 import urllib.parse
 import gc
-from diffusers import StableDiffusionPipeline
+# from diffusers import StableDiffusionPipeline
 from huggingface_hub import snapshot_download
 
 
@@ -34,6 +34,8 @@ if(is_gpu_associated):
     model_v1 = snapshot_download(repo_id="multimodalart/sd-fine-tunable")
     model_v2 = snapshot_download(repo_id="stabilityai/stable-diffusion-2")
     model_v2_512 = snapshot_download(repo_id="stabilityai/stable-diffusion-2-base")
+    model_alt = snapshot_download(repo_id="BAAI/AltDiffusion")
+    model_alt_m9 = snapshot_download(repo_id="BAAI/AltDiffusion-m9")
     safety_checker = snapshot_download(repo_id="multimodalart/sd-sc")
     model_to_load = model_v1
 
@@ -69,6 +71,10 @@ def swap_base_model(selected_model):
         model_to_load = model_v1
     elif(selected_model == "v2-768"):
         model_to_load = model_v2
+    elif(selected_model == "alt"):
+        model_to_load = model_alt
+    elif(selected_model == "alt_m9"):
+        model_to_load = model_alt_m9
     else:
         model_to_load = model_v2_512
 
@@ -288,11 +294,13 @@ def train(*inputs):
     pipe_is_set = False
     def generate(prompt, steps):
         torch.cuda.empty_cache()
-        from diffusers import StableDiffusionPipeline
+        # from diffusers import StableDiffusionPipeline
+        from diffusers import DiffusionPipeline
         global pipe_is_set
         if(not pipe_is_set):
             global pipe
-            pipe = StableDiffusionPipeline.from_pretrained("./output_model", torch_dtype=torch.float16)
+            # pipe = StableDiffusionPipeline.from_pretrained("./output_model", torch_dtype=torch.float16)
+            pipe = DiffusionPipeline.from_pretrained("./output_model", torch_dtype=torch.float16)
             pipe = pipe.to("cuda")
             pipe_is_set = True
 
@@ -477,7 +485,7 @@ with gr.Blocks(css=css) as demo:
 
     with gr.Row() as what_are_you_training:
         type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
-        base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-5", "v2-512", "v2-768"], value="v1-5", interactive=True)
+        base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-5", "v2-512", "v2-768", "alt", "alt_m9"], value="alt_m9", interactive=True)
 
     #Very hacky approach to emulate dynamically created Gradio components
     with gr.Row() as upload_your_concept:
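Note on the generate() change above: DiffusionPipeline re-loads the trained output with whatever pipeline class is recorded in ./output_model/model_index.json, so the inference helper works unchanged whether the selected base was Stable Diffusion or AltDiffusion. A short sketch of that loading path (prompt and step count are illustrative only):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("./output_model", torch_dtype=torch.float16)
print(type(pipe).__name__)  # e.g. AltDiffusionPipeline when an "alt" base model was trained
pipe = pipe.to("cuda")
image = pipe("a photo of sks person", num_inference_steps=25).images[0]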
train_dreambooth.py
CHANGED
@@ -17,17 +17,34 @@ from torch.utils.data import Dataset
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from huggingface_hub import HfFolder, Repository, whoami
 from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import AutoTokenizer, PretrainedConfig
 
 
 logger = get_logger(__name__)
 
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
+    text_encoder_config = PretrainedConfig.from_pretrained(
+        pretrained_model_name_or_path,
+        subfolder="text_encoder",
+    )
+    model_class = text_encoder_config.architectures[0]
+
+    if model_class == "CLIPTextModel":
+        from transformers import CLIPTextModel
+
+        return CLIPTextModel
+    elif model_class == "RobertaSeriesModelWithTransformation":
+        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+        return RobertaSeriesModelWithTransformation
+    else:
+        raise ValueError(f"{model_class} is not supported.")
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -471,7 +488,7 @@ def run_training(args_imported):
 
        if cur_class_images < args.num_class_images:
            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
-           pipeline = StableDiffusionPipeline.from_pretrained(
+           pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path, torch_dtype=torch_dtype
            )
            pipeline.set_progress_bar_config(disable=True)
@@ -517,20 +534,27 @@ def run_training(args_imported):
 
    # Load the tokenizer
    if args.tokenizer_name:
-       tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+       tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=False)
    elif args.pretrained_model_name_or_path:
-       tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+       tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", use_fast=False)
+
+   # support for Altdiffusion
+   text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
 
    # Load models and create wrapper for stable diffusion
    if args.train_only_unet:
        if os.path.exists(str(args.output_dir+"/text_encoder_trained")):
-           text_encoder = CLIPTextModel.from_pretrained(args.output_dir, subfolder="text_encoder_trained")
+           # text_encoder = CLIPTextModel.from_pretrained(args.output_dir, subfolder="text_encoder_trained")
+           text_encoder = text_encoder_cls.from_pretrained(args.output_dir, subfolder="text_encoder_trained")
        elif os.path.exists(str(args.output_dir+"/text_encoder")):
-           text_encoder = CLIPTextModel.from_pretrained(args.output_dir, subfolder="text_encoder")
+           # text_encoder = CLIPTextModel.from_pretrained(args.output_dir, subfolder="text_encoder")
+           text_encoder = text_encoder_cls.from_pretrained(args.output_dir, subfolder="text_encoder")
        else:
-           text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+           # text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+           text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    else:
-       text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+       # text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+       text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
 
@@ -796,7 +820,7 @@ def run_training(args_imported):
        if os.path.exists(frz_dir):
            subprocess.call('rm -r '+ frz_dir, shell=True)
        os.mkdir(frz_dir)
-       pipeline = StableDiffusionPipeline.from_pretrained(
+       pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
@@ -816,7 +840,7 @@ def run_training(args_imported):
                print(" [1;32mSAVING CHECKPOINT: "+args.Session_dir+"/"+inst+".ckpt")
        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
-           pipeline = StableDiffusionPipeline.from_pretrained(
+           pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet),
                text_encoder=accelerator.unwrap_model(text_encoder),
@@ -839,7 +863,7 @@ def run_training(args_imported):
                txt_dir=args.output_dir + "/text_encoder_trained"
                if not os.path.exists(txt_dir):
                    os.mkdir(txt_dir)
-               pipeline = StableDiffusionPipeline.from_pretrained(
+               pipeline = DiffusionPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    unet=accelerator.unwrap_model(unet),
                    text_encoder=accelerator.unwrap_model(text_encoder),
@@ -847,7 +871,7 @@ def run_training(args_imported):
                pipeline.text_encoder.save_pretrained(txt_dir)
 
            elif args.train_only_unet:
-               pipeline = StableDiffusionPipeline.from_pretrained(
+               pipeline = DiffusionPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    unet=accelerator.unwrap_model(unet),
                    text_encoder=accelerator.unwrap_model(text_encoder),
@@ -857,7 +881,7 @@ def run_training(args_imported):
                subprocess.call('rm -r '+txt_dir, shell=True)
 
            else:
-               pipeline = StableDiffusionPipeline.from_pretrained(
+               pipeline = DiffusionPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    unet=accelerator.unwrap_model(unet),
                    text_encoder=accelerator.unwrap_model(text_encoder),
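Usage sketch for the new import_model_class_from_model_name_or_path helper (not part of the commit; it assumes the helper defined above is in scope and uses hub repo ids only as examples of the two supported architectures): the helper reads the base model's text_encoder config and returns the class to instantiate, so the training loop never hard-codes CLIPTextModel.

# Stable Diffusion bases report architectures=["CLIPTextModel"] in their text_encoder
# config; AltDiffusion bases report RobertaSeriesModelWithTransformation.
cls_sd = import_model_class_from_model_name_or_path("stabilityai/stable-diffusion-2-base")  # -> CLIPTextModel
cls_alt = import_model_class_from_model_name_or_path("BAAI/AltDiffusion-m9")                # -> RobertaSeriesModelWithTransformation

text_encoder = cls_alt.from_pretrained("BAAI/AltDiffusion-m9", subfolder="text_encoder")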