linoyts HF staff commited on
Commit
e27ec02
1 Parent(s): f2b4569

Update clip_slider_pipeline.py

Browse files
Files changed (1) hide show
  1. clip_slider_pipeline.py +5 -4
clip_slider_pipeline.py CHANGED
@@ -18,7 +18,7 @@ class CLIPSlider:
18
  ):
19
 
20
  self.device = device
21
- self.pipe = sd_pipe.to(self.device)
22
  self.iterations = iterations
23
  if target_word != "" or opposite != "":
24
  self.avg_diff = self.find_latent_direction(target_word, opposite)
@@ -280,13 +280,14 @@ class CLIPSliderXL(CLIPSlider):
280
  prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
281
  prompt_embeds_list.append(prompt_embeds)
282
 
283
- prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
284
- pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
285
  end_time = time.time()
 
286
  print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
287
  torch.manual_seed(seed)
288
  start_time = time.time()
289
- image = self.pipe(prompt_embeds=prompt_embeds.to(torch.float16), pooled_prompt_embeds=pooled_prompt_embeds.to(torch.float16),
290
  **pipeline_kwargs).images[0]
291
  end_time = time.time()
292
  print(f"generation time - pipe: {end_time - start_time:.2f} ms")
 
18
  ):
19
 
20
  self.device = device
21
+ self.pipe = sd_pipe.to(self.device, torch.float16)
22
  self.iterations = iterations
23
  if target_word != "" or opposite != "":
24
  self.avg_diff = self.find_latent_direction(target_word, opposite)
 
280
  prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
281
  prompt_embeds_list.append(prompt_embeds)
282
 
283
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1).to(torch.float16)
284
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1).to(torch.float16)
285
  end_time = time.time()
286
+ print("prompt_embeds", prompt_embeds.dtype)
287
  print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
288
  torch.manual_seed(seed)
289
  start_time = time.time()
290
+ image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
291
  **pipeline_kwargs).images[0]
292
  end_time = time.time()
293
  print(f"generation time - pipe: {end_time - start_time:.2f} ms")