File size: 3,942 Bytes
a55bafa 452abeb a55bafa 452abeb a55bafa 452abeb a55bafa 452abeb a55bafa d7c590b a55bafa 8cfdc75 a55bafa 8cfdc75 a55bafa 452abeb 8cfdc75 a55bafa 452abeb a55bafa 452abeb a55bafa 452abeb a55bafa 452abeb a55bafa 452abeb 8cfdc75 a55bafa 452abeb a55bafa 452abeb a55bafa 452abeb a55bafa 452abeb a55bafa 452abeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
#!/usr/bin/env python3
from diffusers import DiffusionPipeline, DDIMScheduler
import argparse
from diffusers.pipelines.stable_diffusion import safety_checker
import torch
from datasets import load_dataset
import PIL
# Target (width, height) that ``resize`` scales generated images to.
IMAGE_OUTPUT_SIZE = (256, 256)
# Denoising steps shared by the SD sampler, the Karlo decoder and both IF stages.
NUM_INFERENCE_STEPS = 100


def resize(image: "PIL.Image.Image", size=None):
    """Return *image* resized to *size* using Lanczos resampling.

    Args:
        image: The PIL image to resize.
        size: Target ``(width, height)``; ``None`` (the default) means the
            module-level ``IMAGE_OUTPUT_SIZE`` at call time.

    Returns:
        The resized copy produced by ``PIL.Image.Image.resize``.
    """
    # Import the submodule explicitly: ``import PIL`` alone does not load
    # ``PIL.Image``, so relying on it would depend on another library having
    # imported it first.
    import PIL.Image

    if size is None:
        size = IMAGE_OUTPUT_SIZE
    # ``PIL.Image.Resampling`` was added in Pillow 9.1; older releases expose
    # the same filter constants directly on ``PIL.Image``.
    resampling = getattr(PIL.Image, "Resampling", PIL.Image)
    return image.resize(size, resample=resampling.LANCZOS)
def get_sd_eval(ckpt, guidance_scale=7.5):
    """Load a Stable Diffusion checkpoint and return a sampling closure.

    The pipeline is loaded in float16 without a safety checker, moved to CUDA,
    and its scheduler is replaced by a ``DDIMScheduler`` built from the
    existing scheduler config.

    Args:
        ckpt: Checkpoint ID or path passed to ``DiffusionPipeline.from_pretrained``.
        guidance_scale: Classifier-free guidance scale used for every call.

    Returns:
        ``sd_eval(prompt, generator=None)``, which runs the pipeline for
        ``NUM_INFERENCE_STEPS`` steps and returns the images passed through
        ``resize``.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        ckpt, torch_dtype=torch.float16, safety_checker=None
    )
    pipeline.to("cuda")
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

    def sd_eval(prompt, generator=None):
        output = pipeline(
            prompt,
            generator=generator,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=guidance_scale,
        )
        return list(map(resize, output.images))

    return sd_eval
def get_karlo_eval(ckpt):
    """Load a Karlo checkpoint in float16 on CUDA and return a sampling closure.

    Args:
        ckpt: Checkpoint ID or path passed to ``DiffusionPipeline.from_pretrained``.

    Returns:
        ``karlo_eval(prompt, generator=None)``, which runs 50 prior steps and
        ``NUM_INFERENCE_STEPS`` decoder steps and returns the pipeline's images
        without passing them through ``resize``.
    """
    pipeline = DiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.float16)
    pipeline.to("cuda")

    def karlo_eval(prompt, generator=None):
        result = pipeline(
            prompt,
            prior_num_inference_steps=50,
            generator=generator,
            decoder_num_inference_steps=NUM_INFERENCE_STEPS,
        )
        return result.images

    return karlo_eval
def get_if_eval(ckpt):
    """Build a two-stage DeepFloyd IF sampler and return a sampling closure.

    Stage I is loaded from *ckpt*. Stage II is always
    ``DeepFloyd/IF-II-L-v1.0`` and is handed stage I's text encoder. Both
    stages load fp16 weights without a safety checker or watermarker and use
    model CPU offloading.

    Args:
        ckpt: Stage-I checkpoint ID or path.

    Returns:
        ``if_eval(prompt, generator=None)``. It encodes *prompt* once with
        stage I, runs stage I to a tensor output, then feeds that output to
        stage II with the same embeddings and generator. It returns stage II's
        images.
    """
    shared_kwargs = dict(
        safety_checker=None,
        watermarker=None,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    stage_1 = DiffusionPipeline.from_pretrained(ckpt, **shared_kwargs)
    stage_1.enable_model_cpu_offload()
    stage_2 = DiffusionPipeline.from_pretrained(
        "DeepFloyd/IF-II-L-v1.0",
        text_encoder=stage_1.text_encoder,
        **shared_kwargs,
    )
    stage_2.enable_model_cpu_offload()

    def if_eval(prompt, generator=None):
        embeds, negative_embeds = stage_1.encode_prompt(prompt)
        # Both stages condition on the same pre-computed embeddings.
        conditioning = dict(prompt_embeds=embeds, negative_prompt_embeds=negative_embeds)
        low_res = stage_1(
            **conditioning,
            num_inference_steps=NUM_INFERENCE_STEPS,
            generator=generator,
            output_type="pt",
        ).images
        upscaled = stage_2(
            **conditioning,
            image=low_res,
            num_inference_steps=NUM_INFERENCE_STEPS,
            generator=generator,
        ).images
        return upscaled

    return if_eval
# Supported checkpoint IDs mapped to the factory that builds their eval function.
# Each factory is called with the checkpoint ID itself; both Stable Diffusion
# checkpoints share one factory.
MODELS = {
    **dict.fromkeys(
        ("runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-2-1"),
        get_sd_eval,
    ),
    "kakaobrain/karlo-alpha": get_karlo_eval,
    "DeepFloyd/IF-I-XL-v1.0": get_if_eval,
}
def _parse_args(argv=None):
    """Parse command-line options for a Parti prompts evaluation run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Run Parti Prompt Evaluation')
    # Restrict to the IDs in MODELS so an unsupported model fails with a clear
    # usage error instead of a bare KeyError after argument parsing.
    parser.add_argument('model_repo_or_id', type=str, choices=list(MODELS), help='ID of the model repository (one of the supported models).')
    parser.add_argument('--dataset_repo_or_id', type=str, default='diffusers/prompt_generations', help='ID or URL of the dataset repository (default: "diffusers/prompt_generations")')
    parser.add_argument('--batch_size', type=int, default=8, help="Batch size for the eval function")
    parser.add_argument('--upload_to_hub', action='store_true', help='whether to upload the dataset to the Hugging Face dataset hub')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--max_prompts', type=int, default=None, help='Only evaluate the first N prompts (default: all prompts)')
    args = parser.parse_args(argv)
    if args.max_prompts is not None and args.max_prompts < 0:
        parser.error('--max_prompts must be non-negative')
    return args


def main(argv=None):
    """Generate images for every Parti prompt and save or upload the result.

    Each row of the output dataset gains ``images``, ``model_name`` and
    ``seed`` columns. With ``--upload_to_hub`` the dataset is pushed to
    ``--dataset_repo_or_id``. Otherwise it is saved locally in a directory
    named after that ID's last path component.
    """
    args = _parse_args(argv)

    dataset = load_dataset("nateraw/parti-prompts")["train"]
    if args.max_prompts is not None:
        dataset = dataset.select(range(min(args.max_prompts, len(dataset))))

    eval_fn = MODELS[args.model_repo_or_id](args.model_repo_or_id)

    def map_fn(batch):
        prompts = batch["Prompt"]
        # One generator per prompt in *this* batch: the final batch produced by
        # ``dataset.map(batched=True)`` may be smaller than --batch_size, and
        # diffusers rejects a generator list whose length differs from the
        # number of images requested. Every generator gets the same seed, so
        # a prompt's image does not depend on its position in the batch.
        generators = [torch.Generator(device="cuda").manual_seed(args.seed) for _ in prompts]
        batch["images"] = eval_fn(prompts, generator=generators)
        num_images = len(batch["images"])
        batch["model_name"] = [args.model_repo_or_id] * num_images
        batch["seed"] = [args.seed] * num_images
        return batch

    dataset_images = dataset.map(map_fn, batched=True, batch_size=args.batch_size)
    if args.upload_to_hub:
        dataset_images.push_to_hub(args.dataset_repo_or_id)
    else:
        dataset_images.save_to_disk(args.dataset_repo_or_id.split("/")[-1])


if __name__ == "__main__":
    main()
|