LennyHood committed on
Commit
bc76f4d
1 Parent(s): 7c03a1d

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +62 -0
inference.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Generate images from a text prompt with a Latent Consistency Model (LCM).

The script loads the components of a Stable Diffusion checkpoint, overwrites
the UNet weights with an LCM checkpoint stored in a local safetensors file,
runs the LCM pipeline on a fixed prompt and writes the resulting images as
numbered PNG files into ``save_path``.
"""

import os

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from safetensors.torch import load_file
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor

from lcm_pipeline import LatentConsistencyModelPipeline
from lcm_scheduler import LCMScheduler

# Input prompt:
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair"

# Save path (created if it does not exist yet):
save_path = "./lcm_images"
os.makedirs(save_path, exist_ok=True)

# Original SD model ID:
model_id = "digiplay/DreamShaper_7"

# Initialize the Diffusers model components, each from its own subfolder:
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
# NOTE(review): only the UNet is loaded with local_files_only=True, while every
# other component is loaded without it; this looks like it requires the UNet to
# be in the local cache already -- confirm that this asymmetry is intended.
unet = UNet2DConditionModel.from_pretrained(
    model_id,
    subfolder="unet",
    device_map=None,
    low_cpu_mem_usage=False,
    local_files_only=True,
)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(model_id, subfolder="safety_checker")
feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")

# Initialize the scheduler:
scheduler = LCMScheduler(
    beta_start=0.00085,
    beta_end=0.0120,
    beta_schedule="scaled_linear",
    prediction_type="epsilon",
)

# Replace the UNet weights with the LCM checkpoint:
lcm_unet_ckpt = "./LCM_Dreamshaper_v7_4k.safetensors"
ckpt = load_file(lcm_unet_ckpt)
# strict=False tolerates key mismatches; report them instead of failing.
missing_keys, unexpected_keys = unet.load_state_dict(ckpt, strict=False)
if missing_keys:
    print("missing keys:")
    print(missing_keys)
if unexpected_keys:
    print("unexpected keys:")
    print(unexpected_keys)

# LCM pipeline:
pipe = LatentConsistencyModelPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
)
# Use the GPU when one is available; otherwise fall back to the CPU instead of
# raising on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(device)

# Output images:
images = pipe(
    prompt=prompt,
    num_images_per_prompt=4,
    num_inference_steps=4,
    guidance_scale=8.0,
    lcm_origin_steps=50,
).images

# Save images as 0.png, 1.png, ... inside save_path:
for i, image in enumerate(tqdm(images)):
    output_path = os.path.join(save_path, f"{i}.png")
    image.save(output_path)