# diffusion-xl/lib/inference.py
import time
from datetime import datetime

import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import PromptParser
from gradio import Error, Info, Progress
from spaces import GPU

from .loader import Loader
from .logger import Logger
from .utils import cuda_collect, get_output_types, timer
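

# `spaces.GPU` runs the decorated function on ZeroGPU hardware when hosted on Hugging Face Spaces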
@GPU
def generate(
    positive_prompt="",
    negative_prompt="",
    seed=None,
    model="stabilityai/stable-diffusion-xl-base-1.0",
    scheduler="Euler",
    width=1024,
    height=1024,
    guidance_scale=6.0,
    inference_steps=40,
    deepcache=1,
    scale=1,
    num_images=1,
    use_karras=False,
    use_refiner=False,
    progress=Progress(track_tqdm=True),
):
    if not torch.cuda.is_available():
        raise Error("CUDA not available")
    if positive_prompt.strip() == "":
        raise Error("You must enter a prompt")

    KIND = "txt2img"
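    # SDXL is trained on the penultimate, non-normalized hidden states of its text encoders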
    EMBEDDINGS_TYPE = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED

    start = time.perf_counter()
    log = Logger("generate")
    log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")

    loader = Loader()
    loader.load(
        KIND,
        model,
        scheduler,
        deepcache,
        scale,
        use_karras,
        use_refiner,
        progress,
    )
    refiner = loader.refiner
    pipeline = loader.pipeline
    upscaler = loader.upscaler

    # The loader returns no pipeline when the model fails to load (probably a typo in the config)
    if pipeline is None:
        raise Error(f"Error loading {model}")

    # Prompt embeddings for base and refiner
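    # The base pipeline conditions on both text encoders and the pooled output of the second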
    compel_1 = Compel(
        text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
        tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
        requires_pooled=[False, True],
        returned_embeddings_type=EMBEDDINGS_TYPE,
        dtype_for_device_getter=lambda _: pipeline.dtype,
        device=pipeline.device,
    )
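    # The refiner only has the second text encoder, so it gets its own Compel instance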
    compel_2 = Compel(
        text_encoder=[pipeline.text_encoder_2],
        tokenizer=[pipeline.tokenizer_2],
        requires_pooled=[True],
        returned_embeddings_type=EMBEDDINGS_TYPE,
        dtype_for_device_getter=lambda _: pipeline.dtype,
        device=pipeline.device,
    )

    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
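    # manual_seed accepts up to an unsigned 64-bit value, hence the modulo below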
    if seed is None or seed < 0:
        seed = int(datetime.now().timestamp() * 1e6) % (2**64)

    # Increment the seed after each iteration
    images = []
    current_seed = seed
    for _ in range(num_images):
        try:
            generator = torch.Generator(device=pipeline.device).manual_seed(current_seed)
            conditioning_1, pooled_1 = compel_1([positive_prompt, negative_prompt])
            conditioning_2, pooled_2 = compel_2([positive_prompt, negative_prompt])
        except PromptParser.ParsingException:
            raise Error("Invalid prompt")
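
        # get_output_types presumably returns "latent" when a later stage (refiner or
        # upscaler) consumes the output and "pil" otherwise (assumption from its name)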
        pipeline_output_type, refiner_output_type = get_output_types(scale, use_refiner)
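
        # Compel stacks the positive and negative prompts along the batch dimension,
        # so row 0 is the positive embedding and row 1 the negative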
        pipeline_kwargs = {
            "width": width,
            "height": height,
"denoising_end": 0.8 if use_refiner else None,
"generator": generator,
"output_type": pipeline_output_type,
"guidance_scale": guidance_scale,
"num_inference_steps": inference_steps,
"prompt_embeds": conditioning_1[0:1],
"pooled_prompt_embeds": pooled_1[0:1],
"negative_prompt_embeds": conditioning_1[1:2],
"negative_pooled_prompt_embeds": pooled_1[1:2],
}

        refiner_kwargs = {
            "denoising_start": 0.8,
            "generator": generator,
            "output_type": refiner_output_type,
            "guidance_scale": guidance_scale,
            "num_inference_steps": inference_steps,
            "prompt_embeds": conditioning_2[0:1],
            "pooled_prompt_embeds": pooled_2[0:1],
            "negative_prompt_embeds": conditioning_2[1:2],
            "negative_pooled_prompt_embeds": pooled_2[1:2],
        }

        image = pipeline(**pipeline_kwargs).images[0]
        if use_refiner:
            refiner_kwargs["image"] = image
            image = refiner(**refiner_kwargs).images[0]

        # Use a tuple so gallery images get captions
        images.append((image, str(current_seed)))
        current_seed += 1

    # Upscale
    if scale > 1:
        with timer(f"Upscaling {len(images)} image{'s' if len(images) > 1 else ''} {scale}x", logger=log.info):
            for i, (image, image_seed) in enumerate(images):
                images[i] = (upscaler.predict(image), image_seed)

    # Flush the CUDA cache after generating
    cuda_collect()

    end = time.perf_counter()
    msg = f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {end - start:.2f}s"
    log.info(msg)
    Info(msg)
    return images