Spaces:
Paused
Paused
# Prediction interface for Cog ⚙️ | |
# https://github.com/replicate/cog/blob/main/docs/python.md | |
import os | |
import sys | |
sys.path.append(os.path.join(os.path.dirname(__file__), "..")) | |
sys.path.append(os.path.join(os.path.dirname(__file__), "../gradio_demo")) | |
import cv2 | |
import time | |
import torch | |
import mimetypes | |
import subprocess | |
import numpy as np | |
from typing import List | |
from cog import BasePredictor, Input, Path | |
import PIL | |
from PIL import Image | |
import diffusers | |
from diffusers import LCMScheduler | |
from diffusers.utils import load_image | |
from diffusers.models import ControlNetModel | |
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel | |
from model_util import get_torch_device | |
from insightface.app import FaceAnalysis | |
from transformers import CLIPImageProcessor | |
from controlnet_util import openpose, get_depth_map, get_canny_image | |
from diffusers.pipelines.stable_diffusion.safety_checker import ( | |
StableDiffusionSafetyChecker, | |
) | |
from pipeline_stable_diffusion_xl_instantid_full import ( | |
StableDiffusionXLInstantIDPipeline, | |
draw_kps, | |
) | |
# Register the `.webp` extension so it maps to the `image/webp` MIME type.
mimetypes.add_type("image/webp", ".webp")

# GPU global variables
DEVICE = get_torch_device()
# Half precision only when the selected device is CUDA; full precision otherwise.
DTYPE = torch.float16 if "cuda" in str(DEVICE) else torch.float32

# for `ip-adapter`, `ControlNetModel`, and `stable-diffusion-xl-base-1.0`
CHECKPOINTS_CACHE = "./checkpoints"
CHECKPOINTS_URL = "https://weights.replicate.delivery/default/InstantID/checkpoints.tar"

# for `models/antelopev2`
MODELS_CACHE = "./models"
MODELS_URL = "https://weights.replicate.delivery/default/InstantID/models.tar"

# for the safety checker
SAFETY_CACHE = "./safety-cache"
FEATURE_EXTRACTOR = "./feature-extractor"
SAFETY_URL = "https://weights.replicate.delivery/default/playgroundai/safety-cache.tar"
# Common URL prefix of the hosted weight tarballs.
_WEIGHTS_BASE_URL = "https://weights.replicate.delivery/default/InstantID"

# These are all huggingface models that we host via gcp + pget.
# Each item is a "<org>/<name>" repo id; `<name>` is the key users select.
_SDXL_SLUGS = (
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stablediffusionapi/afrodite-xl-v2",
    "stablediffusionapi/albedobase-xl-20",
    "stablediffusionapi/albedobase-xl-v13",
    "stablediffusionapi/animagine-xl-30",
    "stablediffusionapi/anime-art-diffusion-xl",
    "stablediffusionapi/anime-illust-diffusion-xl",
    "stablediffusionapi/dreamshaper-xl",
    "stablediffusionapi/dynavision-xl-v0610",
    "stablediffusionapi/guofeng4-xl",
    "stablediffusionapi/juggernaut-xl-v8",
    "stablediffusionapi/nightvision-xl-0791",
    "stablediffusionapi/omnigen-xl",
    "stablediffusionapi/pony-diffusion-v6-xl",
    "stablediffusionapi/protovision-xl-high-fidel",
    "SG161222/RealVisXL_V3.0_Turbo",
    "SG161222/RealVisXL_V4.0_Lightning",
)


def _weights_entry(slug):
    """Build the `{slug, url, path}` record for one "<org>/<name>" repo id."""
    # Directory / tarball name: "models--<org>--<name>".
    cache_key = "models--" + slug.replace("/", "--")
    return {
        "slug": slug,
        "url": f"{_WEIGHTS_BASE_URL}/{cache_key}.tar",
        "path": f"checkpoints/{cache_key}",
    }


# Maps the short model name to its repo id, tarball URL and local directory.
# Generated from `_SDXL_SLUGS` so the three fields can never disagree.
SDXL_NAME_TO_PATHLIKE = {
    slug.split("/", 1)[1]: _weights_entry(slug) for slug in _SDXL_SLUGS
}
def convert_from_cv2_to_image(img: np.ndarray) -> Image.Image:
    """Convert a BGR array (OpenCV channel order) into an RGB PIL image."""
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
def convert_from_image_to_cv2(img: Image.Image) -> np.ndarray:
    """Convert an RGB PIL image into a BGR array (OpenCV channel order)."""
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
def resize_img(
    input_image,
    max_side=1280,
    min_side=1024,
    size=None,
    pad_to_max_side=False,
    mode=PIL.Image.BILINEAR,
    base_pixel_number=64,
):
    """Resize a PIL image to dimensions that are multiples of `base_pixel_number`.

    Args:
        input_image: PIL image to resize.
        max_side: Target length of the longer side when `size` is not given;
            also the edge length of the padded square canvas.
        min_side: Used for the first scaling pass. The second pass rescales so
            the longer side equals `max_side`, so this value only influences
            intermediate rounding.
        size: Optional explicit `(width, height)`; when given it is used as-is
            and no aspect-ratio computation or snapping is performed.
        pad_to_max_side: When true, centre the result on a white
            `max_side` x `max_side` canvas.
        mode: PIL resampling filter.
        base_pixel_number: Width and height are rounded down to a multiple of
            this value (but never below one multiple).

    Returns:
        The resized (and optionally padded) PIL image.
    """
    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        ratio = min_side / min(h, w)
        w, h = round(ratio * w), round(ratio * h)
        ratio = max_side / max(h, w)
        # Keep at least one pixel so very elongated inputs do not crash resize.
        scaled_w = max(1, round(ratio * w))
        scaled_h = max(1, round(ratio * h))
        input_image = input_image.resize((scaled_w, scaled_h), mode)
        # Snap down to the pixel grid; clamp so a side shorter than one grid
        # cell does not collapse to 0 (which PIL rejects).
        w_resize_new = max(
            base_pixel_number, (scaled_w // base_pixel_number) * base_pixel_number
        )
        h_resize_new = max(
            base_pixel_number, (scaled_h // base_pixel_number) * base_pixel_number
        )
    input_image = input_image.resize((w_resize_new, h_resize_new), mode)

    if pad_to_max_side:
        # White square canvas with the resized image pasted in the centre.
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new] = (
            np.array(input_image)
        )
        input_image = Image.fromarray(res)
    return input_image
def download_weights(url, dest):
    """Fetch `url` into `dest` by running `pget`, printing progress and timing.

    The `-x` flag is added to the `pget` command whenever the URL contains
    ".tar". A non-zero exit status is logged and the
    `subprocess.CalledProcessError` is re-raised to the caller.
    """
    began = time.time()
    print("[!] Initiating download from URL: ", url)
    print("[~] Destination path: ", dest)

    extract_flag = ["-x"] if ".tar" in url else []
    pget_args = ["pget", "-vf", url, dest] + extract_flag

    try:
        subprocess.check_call(pget_args, close_fds=False)
    except subprocess.CalledProcessError as err:
        failed_cmd = " ".join(err.cmd)
        print(
            f"[ERROR] Failed to download weights. Command '{failed_cmd}' returned non-zero exit status {err.returncode}."
        )
        raise

    print("[+] Download completed in: ", time.time() - began, "seconds")
class Predictor(BasePredictor):
    """Cog predictor wrapping the InstantID SDXL pipeline.

    State created by `setup()` and reused across predictions:
        app: insightface `FaceAnalysis` instance used for face detection.
        controlnet_identitynet: the InstantID face ControlNet.
        controlnet_map / controlnet_map_fn: optional pose/canny/depth
            ControlNets and the functions producing their conditioning images.
        pipe: the currently loaded `StableDiffusionXLInstantIDPipeline`.
        base_weights: name of the SDXL weights loaded in `pipe`, or `None`
            while a (re)load is in progress or has failed.
        safety_checker / feature_extractor: NSFW checker and its preprocessor.
    """

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        if not os.path.exists(CHECKPOINTS_CACHE):
            download_weights(CHECKPOINTS_URL, CHECKPOINTS_CACHE)
        if not os.path.exists(MODELS_CACHE):
            download_weights(MODELS_URL, MODELS_CACHE)

        # Face detector; `predict` re-prepares it when another size is requested.
        self.face_detection_input_width, self.face_detection_input_height = 640, 640
        self.app = FaceAnalysis(
            name="antelopev2",
            root="./",
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        self.app.prepare(
            ctx_id=0,
            det_size=(
                self.face_detection_input_width,
                self.face_detection_input_height,
            ),
        )

        # Path to InstantID models
        self.face_adapter = "./checkpoints/ip-adapter.bin"
        controlnet_path = "./checkpoints/ControlNetModel"

        # Load pipeline face ControlNetModel
        self.controlnet_identitynet = ControlNetModel.from_pretrained(
            controlnet_path,
            torch_dtype=DTYPE,
            cache_dir=CHECKPOINTS_CACHE,
            local_files_only=True,
        )

        self.setup_extra_controlnets()
        self.load_weights("stable-diffusion-xl-base-1.0")
        self.setup_safety_checker()

    def setup_safety_checker(self):
        """Download (if missing) and load the safety checker and its feature extractor."""
        print("[~] Seting up safety checker")
        if not os.path.exists(SAFETY_CACHE):
            download_weights(SAFETY_URL, SAFETY_CACHE)

        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            SAFETY_CACHE,
            torch_dtype=DTYPE,
            local_files_only=True,
        )
        self.safety_checker.to(DEVICE)
        self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)

    def run_safety_checker(self, image):
        """Run the safety checker on one image.

        Returns:
            The `(image, has_nsfw_concept)` pair returned by the safety checker.
        """
        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
            DEVICE
        )
        np_image = np.array(image)
        image, has_nsfw_concept = self.safety_checker(
            images=[np_image],
            clip_input=safety_checker_input.pixel_values.to(DTYPE),
        )
        return image, has_nsfw_concept

    def load_weights(self, sdxl_weights):
        """Load the SDXL base weights named `sdxl_weights` into `self.pipe`.

        Downloads the weights when their directory is missing, builds the
        pipeline, attaches the IP-adapter and the LCM LoRA, and moves the
        pipeline to the GPU. `self.base_weights` is updated only after all of
        that succeeded.

        Raises:
            KeyError: if `sdxl_weights` is not a key of `SDXL_NAME_TO_PATHLIKE`.
        """
        weights_info = SDXL_NAME_TO_PATHLIKE[sdxl_weights]
        path_to_weights_dir = weights_info["path"]
        if not os.path.exists(path_to_weights_dir):
            download_weights(weights_info["url"], path_to_weights_dir)

        path_to_weights_file = os.path.join(
            path_to_weights_dir,
            weights_info.get("file", ""),
        )
        print(f"[~] Loading new SDXL weights: {path_to_weights_file}")

        # Invalidate the cache key before replacing the pipeline: if any step
        # below raises, the next prediction will attempt a full reload instead
        # of reusing a half-configured pipeline.
        self.base_weights = None

        if "slug" in weights_info:
            self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
                weights_info["slug"],
                controlnet=[self.controlnet_identitynet],
                torch_dtype=DTYPE,
                cache_dir=CHECKPOINTS_CACHE,
                local_files_only=True,
                safety_checker=None,
                feature_extractor=None,
            )
            self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(
                self.pipe.scheduler.config
            )
        else:  # e.g. .safetensors, NOTE: This functionality is not being used right now
            # `from_single_file` returns a new pipeline; it must be assigned.
            self.pipe = StableDiffusionXLInstantIDPipeline.from_single_file(
                path_to_weights_file,
                controlnet=self.controlnet_identitynet,
                torch_dtype=DTYPE,
                cache_dir=CHECKPOINTS_CACHE,
            )

        self.pipe.load_ip_adapter_instantid(self.face_adapter)
        self.setup_lcm_lora()
        self.pipe.cuda()
        self.base_weights = sdxl_weights

    def setup_lcm_lora(self):
        """Download (if missing) the LCM LoRA, load it into `self.pipe`, and leave it disabled."""
        print("[~] Seting up LCM (just in case)")
        lcm_lora_key = "models--latent-consistency--lcm-lora-sdxl"
        lcm_lora_path = f"checkpoints/{lcm_lora_key}"
        if not os.path.exists(lcm_lora_path):
            download_weights(
                f"https://weights.replicate.delivery/default/InstantID/{lcm_lora_key}.tar",
                lcm_lora_path,
            )

        self.pipe.load_lora_weights(
            "latent-consistency/lcm-lora-sdxl",
            cache_dir=CHECKPOINTS_CACHE,
            local_files_only=True,
            weight_name="pytorch_lora_weights.safetensors",
        )
        # Stays off until a prediction enables LCM.
        self.pipe.disable_lora()

    def setup_extra_controlnets(self):
        """Download (if missing) and load the pose, canny and depth ControlNets."""
        print("[~] Seting up pose, canny, depth ControlNets")
        controlnet_repos = {
            "pose": "thibaud/controlnet-openpose-sdxl-1.0",
            "canny": "diffusers/controlnet-canny-sdxl-1.0",
            "depth": "diffusers/controlnet-depth-sdxl-1.0-small",
        }

        # Fetch every archive first; directories are named "models--<org>--<name>".
        for repo_id in controlnet_repos.values():
            controlnet_key = "models--" + repo_id.replace("/", "--")
            controlnet_path = f"checkpoints/{controlnet_key}"
            if not os.path.exists(controlnet_path):
                download_weights(
                    f"https://weights.replicate.delivery/default/InstantID/{controlnet_key}.tar",
                    controlnet_path,
                )

        self.controlnet_map = {
            name: ControlNetModel.from_pretrained(
                repo_id,
                torch_dtype=DTYPE,
                cache_dir=CHECKPOINTS_CACHE,
                local_files_only=True,
            ).to(DEVICE)
            for name, repo_id in controlnet_repos.items()
        }
        # Functions that turn the control image into each net's conditioning image.
        self.controlnet_map_fn = {
            "pose": openpose,
            "canny": get_canny_image,
            "depth": get_depth_map,
        }

    @staticmethod
    def _face_area(face):
        """Return the bounding-box area of one detected face."""
        x1, y1, x2, y2 = face["bbox"][0], face["bbox"][1], face["bbox"][2], face["bbox"][3]
        return (x2 - x1) * (y2 - y1)

    def generate_image(
        self,
        face_image_path,
        pose_image_path,
        prompt,
        negative_prompt,
        num_steps,
        identitynet_strength_ratio,
        adapter_strength_ratio,
        pose_strength,
        canny_strength,
        depth_strength,
        controlnet_selection,
        guidance_scale,
        seed,
        scheduler,
        enable_LCM,
        enhance_face_region,
        num_images_per_prompt,
    ):
        """Run the InstantID pipeline and return the generated images.

        Args:
            face_image_path: Path of the identity (face) image; required.
            pose_image_path: Optional path of a reference image. When given,
                its face keypoints and its size replace those of the face image.
            prompt / negative_prompt: Text prompts passed to the pipeline.
            num_steps: Number of inference steps.
            identitynet_strength_ratio: Conditioning scale of the face ControlNet.
            adapter_strength_ratio: IP-adapter scale.
            pose_strength / canny_strength / depth_strength: Conditioning scales
                of the optional ControlNets.
            controlnet_selection: List with any of "pose", "canny", "depth".
            guidance_scale: Guidance scale passed to the pipeline.
            seed: Seed for the torch generator.
            scheduler: Scheduler name, optionally suffixed with "-Karras" and
                "-SDE"; ignored when `enable_LCM` is true.
            enable_LCM: Enable the LCM LoRA and use `LCMScheduler`.
            enhance_face_region: When true, pass a mask that is white inside
                the detected face bounding box as `control_mask`.
            num_images_per_prompt: Number of images to generate.

        Returns:
            The `.images` of the pipeline output.

        Raises:
            Exception: if no face image is given or no face is detected.
        """
        if enable_LCM:
            self.pipe.enable_lora()
            self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
        else:
            self.pipe.disable_lora()
            # "<ClassName>[-Karras[-SDE]]": each suffix switches on one option.
            scheduler_class_name, *variants = scheduler.split("-")
            add_kwargs = {}
            if len(variants) >= 1:
                add_kwargs["use_karras_sigmas"] = True
            if len(variants) >= 2:
                add_kwargs["algorithm_type"] = "sde-dpmsolver++"
            scheduler_cls = getattr(diffusers, scheduler_class_name)
            self.pipe.scheduler = scheduler_cls.from_config(
                self.pipe.scheduler.config,
                **add_kwargs,
            )

        if face_image_path is None:
            raise Exception(
                "Cannot find any input face `image`! Please upload the face `image`"
            )

        face_image = load_image(face_image_path)
        face_image = resize_img(face_image)
        face_image_cv2 = convert_from_image_to_cv2(face_image)
        height, width, _ = face_image_cv2.shape

        # Extract face features
        face_info = self.app.get(face_image_cv2)
        if len(face_info) == 0:
            raise Exception(
                "Face detector could not find a face in the `image`. Please use a different `image` as input."
            )

        # only use the maximum face (largest bounding-box area)
        face_info = max(face_info, key=self._face_area)
        face_emb = face_info["embedding"]
        face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
        img_controlnet = face_image

        if pose_image_path is not None:
            pose_image = load_image(pose_image_path)
            pose_image = resize_img(pose_image, max_side=1024)
            img_controlnet = pose_image
            pose_image_cv2 = convert_from_image_to_cv2(pose_image)

            face_info = self.app.get(pose_image_cv2)
            if len(face_info) == 0:
                raise Exception(
                    "Face detector could not find a face in the `pose_image`. Please use a different `pose_image` as input."
                )

            # NOTE(review): takes the last detected face rather than the
            # largest one as done for the face image — confirm this is intended.
            face_info = face_info[-1]
            face_kps = draw_kps(pose_image, face_info["kps"])
            # The output takes the size of the pose image.
            width, height = face_kps.size

        if enhance_face_region:
            # Mask that is white inside the detected face box, black elsewhere.
            control_mask = np.zeros([height, width, 3])
            x1, y1, x2, y2 = (int(v) for v in face_info["bbox"])
            control_mask[y1:y2, x1:x2] = 255
            control_mask = Image.fromarray(control_mask.astype(np.uint8))
        else:
            control_mask = None

        if len(controlnet_selection) > 0:
            controlnet_scales = {
                "pose": pose_strength,
                "canny": canny_strength,
                "depth": depth_strength,
            }
            # The identity net always comes first, followed by the selected extras.
            self.pipe.controlnet = MultiControlNetModel(
                [self.controlnet_identitynet]
                + [self.controlnet_map[s] for s in controlnet_selection]
            )
            control_scales = [float(identitynet_strength_ratio)] + [
                controlnet_scales[s] for s in controlnet_selection
            ]
            control_images = [face_kps] + [
                self.controlnet_map_fn[s](img_controlnet).resize((width, height))
                for s in controlnet_selection
            ]
        else:
            self.pipe.controlnet = self.controlnet_identitynet
            control_scales = float(identitynet_strength_ratio)
            control_images = face_kps

        generator = torch.Generator(device=DEVICE).manual_seed(seed)

        print("Start inference...")
        print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")

        self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
        images = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image_embeds=face_emb,
            image=control_images,
            control_mask=control_mask,
            controlnet_conditioning_scale=control_scales,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
            generator=generator,
            num_images_per_prompt=num_images_per_prompt,
        ).images

        return images

    @staticmethod
    def _save_output(output_image, index, output_format, output_quality):
        """Save one generated image under /tmp and return its path."""
        extension = output_format.lower()
        extension = "jpeg" if extension == "jpg" else extension
        output_path = f"/tmp/out_{index}.{extension}"

        print(f"[~] Saving to {output_path}...")
        print(f"[~] Output format: {extension.upper()}")

        save_params = {"format": extension.upper()}
        # Quality / optimize are only passed for the non-PNG formats.
        if output_format != "png":
            print(f"[~] Output quality: {output_quality}")
            save_params["quality"] = output_quality
            save_params["optimize"] = True

        output_image.save(output_path, **save_params)
        return Path(output_path)

    def predict(
        self,
        image: Path = Input(
            description="Input face image",
        ),
        pose_image: Path = Input(
            description="(Optional) reference pose image",
            default=None,
        ),
        prompt: str = Input(
            description="Input prompt",
            default="a person",
        ),
        negative_prompt: str = Input(
            description="Input Negative Prompt",
            default="",
        ),
        sdxl_weights: str = Input(
            description="Pick which base weights you want to use",
            default="stable-diffusion-xl-base-1.0",
            choices=[
                "stable-diffusion-xl-base-1.0",
                "juggernaut-xl-v8",
                "afrodite-xl-v2",
                "albedobase-xl-20",
                "albedobase-xl-v13",
                "animagine-xl-30",
                "anime-art-diffusion-xl",
                "anime-illust-diffusion-xl",
                "dreamshaper-xl",
                "dynavision-xl-v0610",
                "guofeng4-xl",
                "nightvision-xl-0791",
                "omnigen-xl",
                "pony-diffusion-v6-xl",
                "protovision-xl-high-fidel",
                "RealVisXL_V3.0_Turbo",
                "RealVisXL_V4.0_Lightning",
            ],
        ),
        face_detection_input_width: int = Input(
            description="Width of the input image for face detection",
            default=640,
            ge=640,
            le=4096,
        ),
        face_detection_input_height: int = Input(
            description="Height of the input image for face detection",
            default=640,
            ge=640,
            le=4096,
        ),
        scheduler: str = Input(
            description="Scheduler",
            choices=[
                "DEISMultistepScheduler",
                "HeunDiscreteScheduler",
                "EulerDiscreteScheduler",
                "DPMSolverMultistepScheduler",
                "DPMSolverMultistepScheduler-Karras",
                "DPMSolverMultistepScheduler-Karras-SDE",
            ],
            default="EulerDiscreteScheduler",
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps",
            default=30,
            ge=1,
            le=500,
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance",
            default=7.5,
            ge=1,
            le=50,
        ),
        ip_adapter_scale: float = Input(
            description="Scale for image adapter strength (for detail)",  # adapter_strength_ratio
            default=0.8,
            ge=0,
            le=1.5,
        ),
        controlnet_conditioning_scale: float = Input(
            description="Scale for IdentityNet strength (for fidelity)",  # identitynet_strength_ratio
            default=0.8,
            ge=0,
            le=1.5,
        ),
        enable_pose_controlnet: bool = Input(
            description="Enable Openpose ControlNet, overrides strength if set to false",
            default=True,
        ),
        pose_strength: float = Input(
            description="Openpose ControlNet strength, effective only if `enable_pose_controlnet` is true",
            default=0.4,
            ge=0,
            le=1,
        ),
        enable_canny_controlnet: bool = Input(
            description="Enable Canny ControlNet, overrides strength if set to false",
            default=False,
        ),
        canny_strength: float = Input(
            description="Canny ControlNet strength, effective only if `enable_canny_controlnet` is true",
            default=0.3,
            ge=0,
            le=1,
        ),
        enable_depth_controlnet: bool = Input(
            description="Enable Depth ControlNet, overrides strength if set to false",
            default=False,
        ),
        depth_strength: float = Input(
            description="Depth ControlNet strength, effective only if `enable_depth_controlnet` is true",
            default=0.5,
            ge=0,
            le=1,
        ),
        enable_lcm: bool = Input(
            description="Enable Fast Inference with LCM (Latent Consistency Models) - speeds up inference steps, trade-off is the quality of the generated image. Performs better with close-up portrait face images",
            default=False,
        ),
        lcm_num_inference_steps: int = Input(
            description="Only used when `enable_lcm` is set to True, Number of denoising steps when using LCM",
            default=5,
            ge=1,
            le=10,
        ),
        lcm_guidance_scale: float = Input(
            description="Only used when `enable_lcm` is set to True, Scale for classifier-free guidance when using LCM",
            default=1.5,
            ge=1,
            le=20,
        ),
        enhance_nonface_region: bool = Input(
            description="Enhance non-face region", default=True
        ),
        output_format: str = Input(
            description="Format of the output images",
            choices=["webp", "jpg", "png"],
            default="webp",
        ),
        output_quality: int = Input(
            description="Quality of the output images, from 0 to 100. 100 is best quality, 0 is lowest quality.",
            default=80,
            ge=0,
            le=100,
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed",
            default=None,
        ),
        num_outputs: int = Input(
            description="Number of images to output",
            default=1,
            ge=1,
            le=8,
        ),
        disable_safety_checker: bool = Input(
            description="Disable safety checker for generated images",
            default=False,
        ),
    ) -> List[Path]:
        """Run a single prediction on the model"""
        # If no seed is provided, generate a random seed (32 random bits).
        if seed is None:
            seed = int.from_bytes(os.urandom(4), "big")
        print(f"Using seed: {seed}")

        # Load the weights if they are different from the base weights
        if sdxl_weights != self.base_weights:
            self.load_weights(sdxl_weights)

        # Re-prepare the face detector when a different detection input size
        # is requested than the one it is currently configured with.
        requested_det_size = (face_detection_input_width, face_detection_input_height)
        current_det_size = (
            self.face_detection_input_width,
            self.face_detection_input_height,
        )
        if requested_det_size != current_det_size:
            print(f"[!] Resizing output to {face_detection_input_width}x{face_detection_input_height}")
            self.face_detection_input_width = face_detection_input_width
            self.face_detection_input_height = face_detection_input_height
            self.app.prepare(ctx_id=0, det_size=requested_det_size)

        # Set up ControlNet selection and their respective strength values (if any).
        # A net is used only when it is enabled and its strength is positive.
        controlnet_options = (
            ("pose", enable_pose_controlnet, pose_strength),
            ("canny", enable_canny_controlnet, canny_strength),
            ("depth", enable_depth_controlnet, depth_strength),
        )
        controlnet_selection = [
            name
            for name, enabled, strength in controlnet_options
            if strength > 0 and enabled
        ]

        # Switch to LCM inference steps and guidance scale if LCM is enabled
        if enable_lcm:
            num_inference_steps = lcm_num_inference_steps
            guidance_scale = lcm_guidance_scale

        # Generate
        images = self.generate_image(
            face_image_path=str(image),
            pose_image_path=str(pose_image) if pose_image else None,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_steps=num_inference_steps,
            identitynet_strength_ratio=controlnet_conditioning_scale,
            adapter_strength_ratio=ip_adapter_scale,
            pose_strength=pose_strength,
            canny_strength=canny_strength,
            depth_strength=depth_strength,
            controlnet_selection=controlnet_selection,
            scheduler=scheduler,
            guidance_scale=guidance_scale,
            seed=seed,
            enable_LCM=enable_lcm,
            # The user-facing "enhance non-face region" switch drives the
            # face-box control mask built in `generate_image`.
            enhance_face_region=enhance_nonface_region,
            num_images_per_prompt=num_outputs,
        )

        # Save the generated images and check for NSFW content
        output_paths = []
        for i, output_image in enumerate(images):
            if not disable_safety_checker:
                _, has_nsfw_content_list = self.run_safety_checker(output_image)
                has_nsfw_content = any(has_nsfw_content_list)
                print(f"NSFW content detected: {has_nsfw_content}")
                if has_nsfw_content:
                    raise Exception(
                        "NSFW content detected. Try running it again, or try a different prompt."
                    )

            output_paths.append(
                self._save_output(output_image, i, output_format, output_quality)
            )
        return output_paths