Kohaku-Blueleaf commited on
Commit
9746259
β€’
1 Parent(s): 3798f72

cast to float

Browse files
Files changed (2) hide show
  1. app.py +13 -12
  2. diff.py +13 -1
app.py CHANGED
@@ -66,17 +66,16 @@ def gen(
66
  encode_prompts(sdxl_pipe, full_prompt, DEFAULT_NEGATIVE_PROMPT)
67
  )
68
  set_seed(seed)
69
- with torch.autocast("cuda"):
70
- result = sdxl_pipe(
71
- prompt_embeds=prompt_embeds,
72
- negative_prompt_embeds=negative_prompt_embeds,
73
- pooled_prompt_embeds=pooled_embeds2,
74
- negative_pooled_prompt_embeds=neg_pooled_embeds2,
75
- num_inference_steps=24,
76
- width=1024,
77
- height=1024,
78
- guidance_scale=6.0,
79
- ).images[0]
80
  torch.cuda.empty_cache()
81
  t1 = time_ns()
82
 
@@ -138,7 +137,9 @@ click "Next" button until you get the dragon girl you like.
138
  value=list(DEFAULT_STYLE_LIST)[0],
139
  )
140
  submit = gr.Button("Next", variant="primary")
141
- dtg_output = gr.TextArea(label="DTG output", lines=9, show_copy_button=True)
 
 
142
  cost_time = gr.Markdown()
143
  with gr.Column(scale=4):
144
  result = gr.Image(label="Result", type="numpy", interactive=False)
 
66
  encode_prompts(sdxl_pipe, full_prompt, DEFAULT_NEGATIVE_PROMPT)
67
  )
68
  set_seed(seed)
69
+ result = sdxl_pipe(
70
+ prompt_embeds=prompt_embeds,
71
+ negative_prompt_embeds=negative_prompt_embeds,
72
+ pooled_prompt_embeds=pooled_embeds2,
73
+ negative_pooled_prompt_embeds=neg_pooled_embeds2,
74
+ num_inference_steps=24,
75
+ width=1024,
76
+ height=1024,
77
+ guidance_scale=6.0,
78
+ ).images[0]
 
79
  torch.cuda.empty_cache()
80
  t1 = time_ns()
81
 
 
137
  value=list(DEFAULT_STYLE_LIST)[0],
138
  )
139
  submit = gr.Button("Next", variant="primary")
140
+ dtg_output = gr.TextArea(
141
+ label="DTG output", lines=9, show_copy_button=True
142
+ )
143
  cost_time = gr.Markdown()
144
  with gr.Column(scale=4):
145
  result = gr.Image(label="Result", type="numpy", interactive=False)
diff.py CHANGED
@@ -5,6 +5,8 @@ from diffusers import StableDiffusionXLKDiffusionPipeline
5
  from k_diffusion.sampling import get_sigmas_polyexponential
6
  from k_diffusion.sampling import sample_dpmpp_2m_sde
7
 
 
 
8
 
9
  def set_timesteps_polyexponential(self, orig_sigmas, num_inference_steps, device=None):
10
  self.num_inference_steps = num_inference_steps
@@ -19,6 +21,16 @@ def set_timesteps_polyexponential(self, orig_sigmas, num_inference_steps, device
19
  self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas.new_zeros([1])])
20
 
21
 
 
 
 
 
 
 
 
 
 
 
22
  def load_model(model_id="KBlueLeaf/Kohaku-XL-Epsilon", device="cuda"):
23
  pipe: StableDiffusionXLKDiffusionPipeline
24
  pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
@@ -28,6 +40,7 @@ def load_model(model_id="KBlueLeaf/Kohaku-XL-Epsilon", device="cuda"):
28
  set_timesteps_polyexponential, pipe.scheduler, pipe.scheduler.sigmas
29
  )
30
  pipe.sampler = partial(sample_dpmpp_2m_sde, eta=0.35, solver_type="heun")
 
31
  return pipe
32
 
33
 
@@ -104,4 +117,3 @@ def encode_prompts(pipe: StableDiffusionXLKDiffusionPipeline, prompt, neg_prompt
104
  neg_pooled_embeds2 = torch.mean(torch.stack(neg_pooled_embeds2, dim=0), dim=0)
105
 
106
  return prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2
107
-
 
5
  from k_diffusion.sampling import get_sigmas_polyexponential
6
  from k_diffusion.sampling import sample_dpmpp_2m_sde
7
 
8
+ torch.set_float32_matmul_precision("mediun")
9
+
10
 
11
  def set_timesteps_polyexponential(self, orig_sigmas, num_inference_steps, device=None):
12
  self.num_inference_steps = num_inference_steps
 
21
  self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas.new_zeros([1])])
22
 
23
 
24
+ def model_forward(k_diffusion_model: torch.nn.Module):
25
+ orig_forward = k_diffusion_model.forward
26
+ def forward(*args, **kwargs):
27
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
28
+ result = orig_forward(*args, **kwargs)
29
+ return result.float()
30
+
31
+ return forward
32
+
33
+
34
  def load_model(model_id="KBlueLeaf/Kohaku-XL-Epsilon", device="cuda"):
35
  pipe: StableDiffusionXLKDiffusionPipeline
36
  pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
 
40
  set_timesteps_polyexponential, pipe.scheduler, pipe.scheduler.sigmas
41
  )
42
  pipe.sampler = partial(sample_dpmpp_2m_sde, eta=0.35, solver_type="heun")
43
+ pipe.k_diffusion_model.forward = model_forward(pipe.k_diffusion_model)
44
  return pipe
45
 
46
 
 
117
  neg_pooled_embeds2 = torch.mean(torch.stack(neg_pooled_embeds2, dim=0), dim=0)
118
 
119
  return prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2