patrickvonplaten commited on
Commit
8cfdc75
1 Parent(s): 452abeb
Files changed (1) hide show
  1. parti_prompts.py +6 -4
parti_prompts.py CHANGED
@@ -35,15 +35,16 @@ def get_karlo_eval(ckpt):
35
  return karlo_eval
36
 
37
  def get_if_eval(ckpt):
38
- pipe_low = DiffusionPipeline.from_pretrained(ckpt, safety_checker=None, torch_dtype=torch.float16)
39
  pipe_low.enable_model_cpu_offload()
40
 
41
- pipe_up = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0", safety_checker=None, text_encoder=pipe_low.text_encoder, torch_dtype=torch.float16)
42
  pipe_up.enable_model_cpu_offload()
43
 
44
  def if_eval(prompt, generator=None):
45
- images = pipe_low(prompt, num_inference_steps=NUM_INFERENCE_STEPS, generator=generator, output_type="pt").images
46
- images = pipe_up(promtp=prompt, images=images, num_inference_steps=NUM_INFERENCE_STEPS).images
 
47
  return images
48
 
49
  return if_eval
@@ -69,6 +70,7 @@ if __name__ == "__main__":
69
  args = parser.parse_args()
70
 
71
  dataset = load_dataset("nateraw/parti-prompts")["train"]
 
72
 
73
  eval_fn = MODELS[args.model_repo_or_id](args.model_repo_or_id)
74
 
 
35
  return karlo_eval
36
 
37
  def get_if_eval(ckpt):
38
+ pipe_low = DiffusionPipeline.from_pretrained(ckpt, safety_checker=None, watermarker=None, torch_dtype=torch.float16, variant="fp16")
39
  pipe_low.enable_model_cpu_offload()
40
 
41
+ pipe_up = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0", safety_checker=None, watermarker=None, text_encoder=pipe_low.text_encoder, torch_dtype=torch.float16, variant="fp16")
42
  pipe_up.enable_model_cpu_offload()
43
 
44
  def if_eval(prompt, generator=None):
45
+ prompt_embeds, negative_prompt_embeds = pipe_low.encode_prompt(prompt)
46
+ images = pipe_low(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, num_inference_steps=NUM_INFERENCE_STEPS, generator=generator, output_type="pt").images
47
+ images = pipe_up(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, image=images, num_inference_steps=NUM_INFERENCE_STEPS, generator=generator).images
48
  return images
49
 
50
  return if_eval
 
70
  args = parser.parse_args()
71
 
72
  dataset = load_dataset("nateraw/parti-prompts")["train"]
73
+ # dataset = dataset.select(range(4))
74
 
75
  eval_fn = MODELS[args.model_repo_or_id](args.model_repo_or_id)
76