File size: 3,942 Bytes
a55bafa
 
 
452abeb
a55bafa
 
 
 
 
 
 
 
 
 
 
452abeb
a55bafa
452abeb
a55bafa
452abeb
 
a55bafa
 
 
 
 
 
 
 
 
d7c590b
 
a55bafa
 
 
 
 
8cfdc75
a55bafa
 
8cfdc75
a55bafa
 
452abeb
8cfdc75
 
 
a55bafa
 
452abeb
a55bafa
 
 
452abeb
a55bafa
 
 
 
 
 
 
 
 
452abeb
a55bafa
 
 
452abeb
a55bafa
 
 
452abeb
8cfdc75
a55bafa
452abeb
a55bafa
 
452abeb
 
 
 
a55bafa
 
452abeb
a55bafa
 
452abeb
a55bafa
452abeb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
from diffusers import DiffusionPipeline, DDIMScheduler
import argparse
from diffusers.pipelines.stable_diffusion import safety_checker
import torch
from datasets import load_dataset
import PIL

# Target (width, height) every generated image is resized to (see resize()).
IMAGE_OUTPUT_SIZE = (256, 256)
# Denoising step count passed to each pipeline (and to the IF/Karlo decoder stage).
NUM_INFERENCE_STEPS = 100

def resize(image: PIL.Image.Image) -> PIL.Image.Image:
    """Downsample *image* to IMAGE_OUTPUT_SIZE with Lanczos resampling.

    Note: the original annotation was ``PIL.Image`` — that is the *module*;
    the image class is ``PIL.Image.Image``.
    """
    return image.resize(IMAGE_OUTPUT_SIZE, resample=PIL.Image.Resampling.LANCZOS)

def get_sd_eval(ckpt, guidance_scale=7.5):
    """Return an eval closure for a Stable Diffusion checkpoint.

    The pipeline is loaded in fp16 on CUDA with the safety checker disabled,
    and its scheduler is replaced by a DDIM scheduler built from the same
    config. The returned function maps prompt(s) to resized PIL images.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        ckpt, torch_dtype=torch.float16, safety_checker=None
    )
    pipeline.to("cuda")
    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

    def sd_eval(prompt, generator=None):
        result = pipeline(
            prompt,
            generator=generator,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=guidance_scale,
        )
        return [resize(img) for img in result.images]

    return sd_eval

def get_karlo_eval(ckpt):
    """Return an eval closure for a Karlo (unCLIP) checkpoint.

    Loads the pipeline in fp16 on CUDA; the closure maps prompt(s) to the
    pipeline's output images (50 prior steps, NUM_INFERENCE_STEPS decoder
    steps).
    """
    pipeline = DiffusionPipeline.from_pretrained(ckpt, torch_dtype=torch.float16)
    pipeline.to("cuda")

    def karlo_eval(prompt, generator=None):
        result = pipeline(
            prompt,
            prior_num_inference_steps=50,
            generator=generator,
            decoder_num_inference_steps=NUM_INFERENCE_STEPS,
        )
        return result.images

    return karlo_eval

def get_if_eval(ckpt):
    """Return an eval closure for DeepFloyd IF (stage I + stage II upscaler).

    Stage I is loaded from *ckpt*; stage II is the fixed IF-II-L checkpoint
    sharing stage I's text encoder. Both run fp16 with model CPU offload.
    """
    stage1 = DiffusionPipeline.from_pretrained(
        ckpt,
        safety_checker=None,
        watermarker=None,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    stage1.enable_model_cpu_offload()

    stage2 = DiffusionPipeline.from_pretrained(
        "DeepFloyd/IF-II-L-v1.0",
        safety_checker=None,
        watermarker=None,
        text_encoder=stage1.text_encoder,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    stage2.enable_model_cpu_offload()

    def if_eval(prompt, generator=None):
        # Encode the prompt once and feed the same embeddings to both stages.
        prompt_embeds, negative_embeds = stage1.encode_prompt(prompt)
        low_res = stage1(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_embeds,
            num_inference_steps=NUM_INFERENCE_STEPS,
            generator=generator,
            output_type="pt",
        ).images
        return stage2(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_embeds,
            image=low_res,
            num_inference_steps=NUM_INFERENCE_STEPS,
            generator=generator,
        ).images

    return if_eval

# Registry of supported model repo IDs -> factory that builds the eval closure.
# Each factory is invoked with the repo ID itself (see the __main__ block).
MODELS = {
    "runwayml/stable-diffusion-v1-5": get_sd_eval,
    "stabilityai/stable-diffusion-2-1": get_sd_eval,
    "kakaobrain/karlo-alpha": get_karlo_eval,
    "DeepFloyd/IF-I-XL-v1.0": get_if_eval,
}




if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run Parti Prompt Evaluation')
    parser.add_argument('model_repo_or_id', type=str, help='ID or URL of the model repository.')
    parser.add_argument('--dataset_repo_or_id', type=str, default='diffusers/prompt_generations', help='ID or URL of the dataset repository (default: "diffusers/prompt_generations")')
    parser.add_argument('--batch_size', type=int, default=8, help="Batch size for the eval function")
    parser.add_argument('--upload_to_hub', action='store_true', help='whether to upload the dataset to the Hugging Face dataset hub')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')

    args = parser.parse_args()

    dataset = load_dataset("nateraw/parti-prompts")["train"]
    # dataset = dataset.select(range(4))

    eval_fn = MODELS[args.model_repo_or_id](args.model_repo_or_id)

    def map_fn(batch):
        generators = [torch.Generator(device="cuda").manual_seed(args.seed) for _ in range(args.batch_size)]
        batch["images"] = eval_fn(batch["Prompt"], generator=generators)
        batch["model_name"] = len(batch["images"]) * [args.model_repo_or_id]
        batch["seed"] = len(batch["images"]) * [args.seed]
        return batch

    dataset_images = dataset.map(map_fn, batched=True, batch_size=args.batch_size)

    if args.upload_to_hub:
        dataset_images.push_to_hub(args.dataset_repo_or_id)
    else:
        dataset_images.save_to_disk(args.dataset_repo_or_id.split("/")[-1])