patrickvonplaten commited on
Commit
35016b9
1 Parent(s): dbc57d2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +60 -2
README.md CHANGED
@@ -2,7 +2,7 @@
2
 
3
  This is a collection of scripts that can be useful for various tasks related to the [diffusers library](https://github.com/huggingface/diffusers)
4
 
5
- ## Test against original checkpoints
6
 
7
  **It's very important to have visually the exact same results as the original code bases.!**
8
 
@@ -37,4 +37,62 @@ image.save("/home/patrick_huggingface_co/images/aa_comp.png")
37
 
38
  Both commands should give the following image on a V100:
39
 
40
- ![image](https://huggingface.co/diffusers/tools/resolve/main/aa_orig_comp%20(6).png)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  This is a collection of scripts that can be useful for various tasks related to the [diffusers library](https://github.com/huggingface/diffusers)
4
 
5
+ ## 1. Test against original checkpoints
6
 
7
  **It's very important to have visually the exact same results as the original code bases.!**
8
 
 
37
 
38
  Both commands should give the following image on a V100:
39
 
40
+
41
+ ## 2. Test against [k-diffusion](https://github.com/crowsonkb/k-diffusion):
42
+
43
+ You can run the following script to compare against k-diffusion.
44
+
45
+ See results [here](https://huggingface.co/datasets/patrickvonplaten/images)
46
+
47
+ ```python
48
+ from diffusers import StableDiffusionKDiffusionPipeline, HeunDiscreteScheduler, StableDiffusionPipeline, DPMSolverMultistepScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler
49
+ import torch
50
+ import os
51
+
52
+ seed = 13
53
+ inference_steps = 25
54
+ #checkpoint = "CompVis/stable-diffusion-v1-4"
55
+ checkpoint = "stabilityai/stable-diffusion-2-1"
56
+ prompts = ["astronaut riding horse", "whale falling from sky", "magical forest", "highly photorealistic picture of johnny depp"]
57
+ prompts = 8 * ["highly photorealistic picture of johnny depp"]
58
+ #prompts = prompts[:1]
59
+ samplers = ["sample_dpmpp_2m", "sample_euler", "sample_heun", "sample_dpm_2", "sample_lms"]
60
+ #samplers = samplers[:1]
61
+
62
+ pipe = StableDiffusionKDiffusionPipeline.from_pretrained(checkpoint, torch_dtype=torch.float16, safety_checker=None)
63
+ pipe = pipe.to("cuda")
64
+
65
+ for i, prompt in enumerate(prompts):
66
+ prompt_f = f"{'_'.join(prompt.split())}_{i}"
67
+ for sampler in samplers:
68
+ pipe.set_scheduler(sampler)
69
+ torch.manual_seed(seed + i)
70
+ image = pipe(prompt, num_inference_steps=inference_steps).images[0]
71
+ checkpoint_f = f"{'--'.join(checkpoint.split('/'))}"
72
+ os.makedirs(f"/home/patrick_huggingface_co/images/{checkpoint_f}", exist_ok=True)
73
+ os.makedirs(f"/home/patrick_huggingface_co/images/{checkpoint_f}/{sampler}", exist_ok=True)
74
+ image.save(f"/home/patrick_huggingface_co/images/{checkpoint_f}/{sampler}/{prompt_f}.png")
75
+
76
+
77
+ pipe = StableDiffusionPipeline(**pipe.components)
78
+ pipe = pipe.to("cuda")
79
+
80
+ for i, prompt in enumerate(prompts):
81
+ prompt_f = f"{'_'.join(prompt.split())}_{i}"
82
+ for sampler in samplers:
83
+ if sampler == "sample_euler":
84
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
85
+ elif sampler == "sample_heun":
86
+ pipe.scheduler = HeunDiscreteScheduler.from_config(pipe.scheduler.config)
87
+ elif sampler == "sample_dpmpp_2m":
88
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
89
+ elif sampler == "sample_lms":
90
+ pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
91
+
92
+ torch.manual_seed(seed + i)
93
+ image = pipe(prompt, num_inference_steps=inference_steps).images[0]
94
+ checkpoint_f = f"{'--'.join(checkpoint.split('/'))}"
95
+ os.makedirs("/home/patrick_huggingface_co/images/{checkpoint_f}", exist_ok=True)
96
+ os.makedirs(f"/home/patrick_huggingface_co/images/{checkpoint_f}/{sampler}", exist_ok=True)
97
+ image.save(f"/home/patrick_huggingface_co/images/{checkpoint_f}/{sampler}/{prompt_f}_hf.png")
98
+ ```