# diffusion/lib/inference.py
import os
import time
from datetime import datetime

import torch
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
from compel.prompt_parser import PromptParser
from gradio import Error, Info, Progress
from spaces import GPU, config

from .loader import get_loader
from .logger import Logger
from .utils import annotate_image, cuda_collect, resize_image, timer
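

# `spaces.GPU` requests a ZeroGPU worker for the duration of the call;
# it has no effect when running on regular (non-ZeroGPU) hardware.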
@GPU
def generate(
    positive_prompt="",
    negative_prompt="",
    image_input=None,
    controlnet_input=None,
    ip_adapter_input=None,
    seed=None,
    model="XpucT/Reliberate",
    scheduler="UniPC",
    controlnet_annotator="canny",
    width=512,
    height=512,
    guidance_scale=6.0,
    inference_steps=40,
    denoising_strength=0.8,
    deepcache_interval=1,
    scale=1,
    num_images=1,
    use_karras=False,
    use_ip_adapter_face=False,
    _=Progress(track_tqdm=True),
):
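    """Run the configured diffusion pipeline and return a list of
    (image, seed) tuples suitable for a Gradio gallery."""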
    if not torch.cuda.is_available():
        raise Error("CUDA not available")
    if not positive_prompt.strip():
        raise Error("You must enter a prompt")

    start = time.perf_counter()
    log = Logger("generate")
    log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")
KIND = "img2img" if image_input is not None else "txt2img"
KIND = f"controlnet_{KIND}" if controlnet_input is not None else KIND
EMBEDDINGS_TYPE = ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
FAST_NEGATIVE = "<fast_negative>" in negative_prompt
if ip_adapter_input:
IP_KIND = "full-face" if use_ip_adapter_face else "plus"
else:
IP_KIND = ""
    # ZeroGPU is serverless, so each call should get an ephemeral loader;
    # on localhost a singleton keeps the pipeline in memory between calls
    loader = get_loader(singleton=not config.Config.zero_gpu)
    loader.load(
        KIND,
        IP_KIND,
        model,
        scheduler,
        controlnet_annotator,
        deepcache_interval,
        scale,
        use_karras,
    )
    pipeline = loader.pipeline
    upscaler = loader.upscaler

    # A missing pipeline usually means the model id didn't resolve (e.g. a typo in the config)
    if pipeline is None:
        raise Error(f"Error loading {model}")
    # Load fast negative embedding
    if FAST_NEGATIVE:
        embeddings_dir = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "embeddings")
        )
        pipeline.load_textual_inversion(
            pretrained_model_name_or_path=f"{embeddings_dir}/fast_negative.pt",
            token="<fast_negative>",
        )
    # Embed prompts with weights
    compel = Compel(
        device=pipeline.device,
        tokenizer=pipeline.tokenizer,
        truncate_long_prompts=False,
        text_encoder=pipeline.text_encoder,
        returned_embeddings_type=EMBEDDINGS_TYPE,
        dtype_for_device_getter=lambda _: pipeline.dtype,
        textual_inversion_manager=DiffusersTextualInversionManager(pipeline),
    )
    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
    if seed is None or seed < 0:
        seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)
    # Increment the seed after each iteration
    images = []
    current_seed = seed

    try:
        for i in range(num_images):
            try:
                generator = torch.Generator(device=pipeline.device).manual_seed(current_seed)
                # Pad both embeddings to the same length so guidance can concatenate them
                positive_embeds, negative_embeds = compel.pad_conditioning_tensors_to_same_length(
                    [compel(positive_prompt), compel(negative_prompt)]
                )
            except PromptParser.ParsingException:
                raise Error("Invalid prompt")

            kwargs = {
                "width": width,
                "height": height,
                "generator": generator,
                "prompt_embeds": positive_embeds,
                "guidance_scale": guidance_scale,
                "num_inference_steps": inference_steps,
                "negative_prompt_embeds": negative_embeds,
                "output_type": "np" if scale > 1 else "pil",
            }

            if KIND in ("img2img", "controlnet_img2img"):
                kwargs["strength"] = denoising_strength
                kwargs["image"] = resize_image(image_input, (width, height))
            if KIND == "controlnet_txt2img":
                kwargs["image"] = annotate_image(controlnet_input, controlnet_annotator)
            if KIND == "controlnet_img2img":
                kwargs["control_image"] = annotate_image(controlnet_input, controlnet_annotator)
            if IP_KIND:
                # No size means preserve aspect ratio
                kwargs["ip_adapter_image"] = resize_image(ip_adapter_input)

            image = pipeline(**kwargs).images[0]
            images.append((image, str(current_seed)))  # tuple with seed for gallery caption
            current_seed += 1
    finally:
        # Unload after the loop so the `<fast_negative>` token resolves for every image
        if FAST_NEGATIVE:
            pipeline.unload_textual_inversion()
    # Upscale
    if scale > 1:
        with timer(f"Upscaling {num_images} images {scale}x", logger=log.info):
            for i, (image, image_seed) in enumerate(images):
                images[i] = (upscaler.predict(image), image_seed)  # keep the (image, seed) tuple
    end = time.perf_counter()
    msg = f"Generating {len(images)} image{'s' if len(images) > 1 else ''} took {end - start:.2f}s"
    log.info(msg)
    if Info:
        Info(msg)

    # Flush cache before returning
    cuda_collect()
    return images
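

# Minimal local usage sketch (an illustration, not part of the app): assumes a CUDA
# machine outside ZeroGPU and that this package is importable as `lib`:
#
#   from lib.inference import generate
#
#   images = generate(
#       positive_prompt="a watercolor lighthouse at dusk",
#       negative_prompt="<fast_negative>",
#       seed=42,
#       num_images=2,
#   )
#   for image, image_seed in images:
#       image.save(f"out_{image_seed}.png")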