from __future__ import annotations import gc import pathlib import sys import gradio as gr import PIL.Image import numpy as np import torch from diffusers import StableDiffusionPipeline # sys.path.insert(0, './ReVersion') # below are original import os # import argparse # import torch from PIL import Image # from diffusers import StableDiffusionPipeline # sys.path.insert(0, './ReVersion') # from templates.templates import inference_templates import math """ Inference script for generating batch results """ def make_image_grid(imgs, rows, cols): assert len(imgs) == rows*cols w, h = imgs[0].size grid = Image.new('RGB', size=(cols*w, rows*h)) grid_w, grid_h = grid.size for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h)) return grid def inference_fn( examples: list, prompt: str, num_samples: int, guidance_scale: float, ddim_steps: int, ) -> PIL.Image.Image: # select model_id model_id = pathlib.Path(examples[0]).stem # create inference pipeline if torch.cuda.is_available(): pipe = StableDiffusionPipeline.from_pretrained(os.path.join('experiments', model_id),torch_dtype=torch.float16).to('cuda') else: pipe = StableDiffusionPipeline.from_pretrained(os.path.join('experiments', model_id)).to('cpu') # single text prompt if prompt is not None: prompt_list = [prompt] else: prompt_list = [] for prompt in prompt_list: # insert relation prompt # prompt = prompt.lower().replace("", "").format(placeholder_string) prompt = prompt.lower().replace("", "").format("") # batch generation images = pipe(prompt, num_inference_steps=ddim_steps, guidance_scale=guidance_scale, num_images_per_prompt=num_samples).images # save a grid of images image_grid = make_image_grid(images, rows=2, cols=math.ceil(num_samples/2)) print(image_grid) return image_grid if __name__ == "__main__": inference_fn()