poipiii commited on
Commit
e36655e
1 Parent(s): 07bacbd

test in latnent upcale

Browse files
Files changed (1) hide show
  1. pipeline.py +8 -7
pipeline.py CHANGED
@@ -872,25 +872,26 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
872
  # print(latent_model_input)
873
  print(latent_model_input.shape)
874
  # predict the noise residual
875
- noise_pred = self.unet(latent_model_input, t,
876
  encoder_hidden_states=text_embeddings).sample
877
  print("noise_pred")
878
  # print(noise_pred)
879
- print(noise_pred.shape)
880
 
881
  print("perform guidance")
882
  # perform guidance
883
  if do_classifier_free_guidance:
884
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
885
- noise_pred = noise_pred_uncond + guidance_scale * \
 
886
  (noise_pred_text - noise_pred_uncond)
887
  print("noise_pred after guidance")
888
- print(noise_pred.shape)
889
-
890
  print("compute the previous noisy sample")
891
  # compute the previous noisy sample x_t -> x_t-1
892
  upscale_latents = self.scheduler.step(
893
- noise_pred, t, upscale_latents, **extra_step_kwargs).prev_sample
894
  print(upscale_latents.shape)
895
 
896
  print("compute mask")
 
872
  # print(latent_model_input)
873
  print(latent_model_input.shape)
874
  # predict the noise residual
875
+ upscale_noise_pred = self.unet(latent_model_input, t,
876
  encoder_hidden_states=text_embeddings).sample
877
  print("noise_pred")
878
  # print(noise_pred)
879
+ print(upscale_noise_pred.shape)
880
 
881
  print("perform guidance")
882
  # perform guidance
883
  if do_classifier_free_guidance:
884
+ noise_pred_uncond, noise_pred_text = upscale_noise_pred.chunk(
885
+ 2)
886
+ upscale_noise_pred = noise_pred_uncond + guidance_scale * \
887
  (noise_pred_text - noise_pred_uncond)
888
  print("noise_pred after guidance")
889
+ print(upscale_noise_pred.shape)
890
+
891
  print("compute the previous noisy sample")
892
  # compute the previous noisy sample x_t -> x_t-1
893
  upscale_latents = self.scheduler.step(
894
+ upscale_noise_pred, t, upscale_latents, **extra_step_kwargs).prev_sample
895
  print(upscale_latents.shape)
896
 
897
  print("compute mask")