File size: 4,652 Bytes
5704d1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import argparse
import logging
from diffusers import AmusedPipeline
import os
from peft import PeftModel
from diffusers import UVit2DModel
# Module-level logger named after this module; main() reports its progress through it.
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """Parse the command-line options of the style-sampling script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``pretrained_model_name_or_path``,
        ``revision``, ``variant``, ``style_descriptor``,
        ``load_transformer_from``, ``load_transformer_lora_from``, ``device``,
        ``batch_size`` and ``write_images_to``.

    Raises:
        SystemExit: raised by argparse when a required option is missing, a
            value cannot be converted, or ``--batch_size`` is smaller than 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--style_descriptor",
        type=str,
        default="[V]",
        help="Text inserted into every prompt as 'in <style_descriptor> style'.",
    )
    parser.add_argument(
        "--load_transformer_from",
        type=str,
        required=False,
        default=None,
        help="Optional path or identifier of a transformer to use instead of the pipeline's own.",
    )
    parser.add_argument(
        "--load_transformer_lora_from",
        type=str,
        required=False,
        default=None,
        help="Optional path to PEFT/LoRA adapter weights to apply to the pipeline's transformer.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device the pipeline is moved to before generation.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Number of prompts passed to the pipeline per call (must be >= 1).",
    )
    parser.add_argument(
        "--write_images_to",
        type=str,
        required=True,
        help="Directory the generated PNG files are written to.",
    )

    args = parser.parse_args(argv)

    # A batch size of 0 would later crash range() with a zero step, and a
    # negative one would silently produce no images; reject both up front.
    if args.batch_size < 1:
        parser.error(f"--batch_size must be a positive integer, got {args.batch_size}")

    return args
def _build_prompts(style_descriptor):
    """Return the fixed prompt list, each phrased '<subject> in <style_descriptor> style'."""
    subjects = (
        "A chihuahua",
        "A tabby cat",
        "A portrait of chihuahua",
        "An apple on the table",
        "A banana on the table",
        "A church on the street",
        "A church in the mountain",
        "A church in the field",
        "A church on the beach",
        "A chihuahua walking on the street",
        "A tabby cat walking on the street",
        "A portrait of tabby cat",
        "An apple on the dish",
        "A banana on the dish",
        "A human walking on the street",
        "A temple on the street",
        "A temple in the mountain",
        "A temple in the field",
        "A temple on the beach",
        "A chihuahua walking in the forest",
        "A tabby cat walking in the forest",
        "A portrait of human face",
        "An apple on the ground",
        "A banana on the ground",
        "A human walking in the forest",
        "A cabin on the street",
        "A cabin in the mountain",
        "A cabin in the field",
        "A cabin on the beach",
    )
    return [f"{subject} in {style_descriptor} style" for subject in subjects]


def main(args):
    """Generate one image per built-in prompt and save each as a PNG.

    Args:
        args: Namespace-like object providing ``pretrained_model_name_or_path``,
            ``revision``, ``variant``, ``style_descriptor``,
            ``load_transformer_from``, ``load_transformer_lora_from``,
            ``device``, ``batch_size`` and ``write_images_to``.

    Raises:
        ValueError: if ``args.batch_size`` is smaller than 1.

    Side effects:
        Creates ``args.write_images_to`` if needed and writes
        ``<prompt>.png`` into it for every prompt.
    """
    # Fail before any model is loaded: 0 would crash range() below and a
    # negative value would silently generate nothing.
    if args.batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {args.batch_size}")

    prompts = _build_prompts(args.style_descriptor)
    logger.warning("generating image for %s", prompts)

    logger.warning("loading models")

    # Only override the pipeline's transformer when one was explicitly requested.
    pipe_args = {}
    if args.load_transformer_from is not None:
        pipe_args["transformer"] = UVit2DModel.from_pretrained(args.load_transformer_from)

    pipe = AmusedPipeline.from_pretrained(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        revision=args.revision,
        variant=args.variant,
        **pipe_args,
    )

    if args.load_transformer_lora_from is not None:
        # The adapter must come from the LoRA path; the previous code passed
        # args.load_transformer_from here, which loaded the wrong weights or
        # raised TypeError when that option was not given.
        pipe.transformer = PeftModel.from_pretrained(
            pipe.transformer, args.load_transformer_lora_from, is_trainable=False
        )

    pipe.to(args.device)

    logger.warning("generating images")

    os.makedirs(args.write_images_to, exist_ok=True)

    for start in range(0, len(prompts), args.batch_size):
        batch = prompts[start:start + args.batch_size]
        images = pipe(batch).images
        # Images are paired with prompts by position; the prompt text is the file name.
        for prompt, image in zip(batch, images):
            image.save(os.path.join(args.write_images_to, prompt + ".png"))
# Script entry point: parse the command line and run the generation when executed directly.
if __name__ == "__main__":
    main(parse_args())