# File: IDfy-Avatarifyy / cog / predict.py
# Hugging Face Hub page header (kept for provenance): uploader "yashvii";
# commit message "Upload folder using huggingface_hub"; revision 14c8ffd (verified);
# file size 29.6 kB.
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(os.path.join(os.path.dirname(__file__), "../gradio_demo"))
import cv2
import time
import torch
import mimetypes
import subprocess
import numpy as np
from typing import List
from cog import BasePredictor, Input, Path
import PIL
from PIL import Image
import diffusers
from diffusers import LCMScheduler
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from model_util import get_torch_device
from insightface.app import FaceAnalysis
from transformers import CLIPImageProcessor
from controlnet_util import openpose, get_depth_map, get_canny_image
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from pipeline_stable_diffusion_xl_instantid_full import (
StableDiffusionXLInstantIDPipeline,
draw_kps,
)
# Register the ".webp" extension with the "image/webp" MIME type.
mimetypes.add_type("image/webp", ".webp")

# GPU global variables
DEVICE = get_torch_device()
# Half precision when the device string mentions CUDA, full precision otherwise.
DTYPE = torch.float16 if "cuda" in str(DEVICE) else torch.float32

# for `ip-adapter`, `ControlNetModel`, and `stable-diffusion-xl-base-1.0`
CHECKPOINTS_CACHE = "./checkpoints"
CHECKPOINTS_URL = "https://weights.replicate.delivery/default/InstantID/checkpoints.tar"

# for `models/antelopev2`
MODELS_CACHE = "./models"
MODELS_URL = "https://weights.replicate.delivery/default/InstantID/models.tar"

# for the safety checker
SAFETY_CACHE = "./safety-cache"
FEATURE_EXTRACTOR = "./feature-extractor"
SAFETY_URL = "https://weights.replicate.delivery/default/playgroundai/safety-cache.tar"
# Common URL prefix under which every packaged SDXL snapshot tarball is hosted.
_SDXL_WEIGHTS_BASE_URL = "https://weights.replicate.delivery/default/InstantID"

# (Hugging Face organisation, repository name) of each selectable SDXL base
# model, listed in registration order. The repository name doubles as the
# lookup key of SDXL_NAME_TO_PATHLIKE.
_SDXL_REPOS = (
    ("stabilityai", "stable-diffusion-xl-base-1.0"),
    ("stablediffusionapi", "afrodite-xl-v2"),
    ("stablediffusionapi", "albedobase-xl-20"),
    ("stablediffusionapi", "albedobase-xl-v13"),
    ("stablediffusionapi", "animagine-xl-30"),
    ("stablediffusionapi", "anime-art-diffusion-xl"),
    ("stablediffusionapi", "anime-illust-diffusion-xl"),
    ("stablediffusionapi", "dreamshaper-xl"),
    ("stablediffusionapi", "dynavision-xl-v0610"),
    ("stablediffusionapi", "guofeng4-xl"),
    ("stablediffusionapi", "juggernaut-xl-v8"),
    ("stablediffusionapi", "nightvision-xl-0791"),
    ("stablediffusionapi", "omnigen-xl"),
    ("stablediffusionapi", "pony-diffusion-v6-xl"),
    ("stablediffusionapi", "protovision-xl-high-fidel"),
    ("SG161222", "RealVisXL_V3.0_Turbo"),
    ("SG161222", "RealVisXL_V4.0_Lightning"),
)

# These are all huggingface models that we host via gcp + pget.
# Each entry maps a model name to:
#   "slug": the "<org>/<repo>" identifier,
#   "url":  the tarball to download,
#   "path": the local directory checked for / used as the download target.
SDXL_NAME_TO_PATHLIKE = {
    repo: {
        "slug": f"{org}/{repo}",
        "url": f"{_SDXL_WEIGHTS_BASE_URL}/models--{org}--{repo}.tar",
        "path": f"checkpoints/models--{org}--{repo}",
    }
    for org, repo in _SDXL_REPOS
}
def convert_from_cv2_to_image(img: np.ndarray) -> Image.Image:
    """Convert an OpenCV-style BGR array into a PIL image with RGB channel order.

    Args:
        img: Image array in BGR channel order (as consumed by
            ``cv2.COLOR_BGR2RGB``).

    Returns:
        A ``PIL.Image.Image`` built from the channel-swapped array.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)
def convert_from_image_to_cv2(img: Image.Image) -> np.ndarray:
    """Convert a PIL image into an OpenCV-style array with BGR channel order.

    Args:
        img: PIL image; it is converted with ``np.array`` and then passed
            through ``cv2.COLOR_RGB2BGR``, so it is expected to be RGB.

    Returns:
        The channel-swapped ``numpy`` array.
    """
    rgb = np.array(img)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
def resize_img(
    input_image,
    max_side=1280,
    min_side=1024,
    size=None,
    pad_to_max_side=False,
    mode=PIL.Image.BILINEAR,
    base_pixel_number=64,
):
    """Resize a PIL image, optionally centring it on a square canvas.

    Args:
        input_image: PIL image to resize.
        max_side: Target length of the longer side when ``size`` is None;
            also the edge length of the padding canvas.
        min_side: Target length of the shorter side for the intermediate
            scaling step when ``size`` is None.
        size: Explicit ``(width, height)``; when given, the automatic
            scaling is skipped and the image is resized straight to it.
        pad_to_max_side: When true, paste the result in the centre of a
            ``max_side`` x ``max_side`` canvas filled with 255.
        mode: Resampling filter handed to ``Image.resize``.
        base_pixel_number: Both sides are rounded down to a multiple of
            this value when ``size`` is None.

    Returns:
        The resized (and possibly padded) PIL image.
    """
    src_w, src_h = input_image.size

    if size is None:
        # Scale so the shorter side equals ``min_side`` ...
        scale = min_side / min(src_h, src_w)
        src_w, src_h = round(scale * src_w), round(scale * src_h)
        # ... then rescale so the longer side equals ``max_side``.
        scale = max_side / max(src_h, src_w)
        scaled_w, scaled_h = round(scale * src_w), round(scale * src_h)
        input_image = input_image.resize((scaled_w, scaled_h), mode)
        # Snap both sides down to the nearest multiple of ``base_pixel_number``.
        target_w = (scaled_w // base_pixel_number) * base_pixel_number
        target_h = (scaled_h // base_pixel_number) * base_pixel_number
    else:
        target_w, target_h = size

    input_image = input_image.resize((target_w, target_h), mode)

    if not pad_to_max_side:
        return input_image

    # Centre the resized image on a 255-filled square canvas.
    canvas = np.full((max_side, max_side, 3), 255, dtype=np.uint8)
    left = (max_side - target_w) // 2
    top = (max_side - target_h) // 2
    canvas[top : top + target_h, left : left + target_w] = np.array(input_image)
    return Image.fromarray(canvas)
def download_weights(url, dest):
    """Fetch ``url`` into ``dest`` by running the ``pget`` command-line tool.

    The ``-x`` flag is appended whenever the URL contains ".tar"
    (presumably asking pget to extract the archive -- confirm against the
    pget documentation). Progress and elapsed time are printed to stdout.

    Args:
        url: Location of the file to download.
        dest: Destination path passed to ``pget``.

    Raises:
        subprocess.CalledProcessError: Re-raised, after an error line is
            printed, when ``pget`` exits with a non-zero status.
    """
    started_at = time.time()
    print("[!] Initiating download from URL: ", url)
    print("[~] Destination path: ", dest)

    pget_cmd = ["pget", "-vf", url, dest] + (["-x"] if ".tar" in url else [])
    try:
        subprocess.check_call(pget_cmd, close_fds=False)
    except subprocess.CalledProcessError as err:
        print(
            f"[ERROR] Failed to download weights. Command '{' '.join(err.cmd)}' returned non-zero exit status {err.returncode}."
        )
        raise

    elapsed = time.time() - started_at
    print("[+] Download completed in: ", elapsed, "seconds")
class Predictor(BasePredictor):
    """Cog predictor that runs the InstantID SDXL pipeline.

    ``setup`` downloads (when missing) and loads the face analyser, the
    IdentityNet ControlNet, the optional pose/canny/depth ControlNets, the
    default SDXL base weights and the safety checker. ``predict`` maps the
    request inputs onto ``generate_image`` and writes the results to ``/tmp``.
    """

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        if not os.path.exists(CHECKPOINTS_CACHE):
            download_weights(CHECKPOINTS_URL, CHECKPOINTS_CACHE)

        if not os.path.exists(MODELS_CACHE):
            download_weights(MODELS_URL, MODELS_CACHE)

        # Detection size currently configured on the face analyser; `predict`
        # compares against these values before re-preparing it.
        self.face_detection_input_width, self.face_detection_input_height = 640, 640
        self.app = FaceAnalysis(
            name="antelopev2",
            root="./",
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        self.app.prepare(
            ctx_id=0,
            det_size=(
                self.face_detection_input_width,
                self.face_detection_input_height,
            ),
        )

        # Path to InstantID models
        self.face_adapter = "./checkpoints/ip-adapter.bin"
        controlnet_path = "./checkpoints/ControlNetModel"

        # Load pipeline face ControlNetModel
        self.controlnet_identitynet = ControlNetModel.from_pretrained(
            controlnet_path,
            torch_dtype=DTYPE,
            cache_dir=CHECKPOINTS_CACHE,
            local_files_only=True,
        )
        self.setup_extra_controlnets()

        self.load_weights("stable-diffusion-xl-base-1.0")
        self.setup_safety_checker()

    def setup_safety_checker(self):
        """Download (if missing) and load the safety checker and its feature extractor."""
        print("[~] Seting up safety checker")
        if not os.path.exists(SAFETY_CACHE):
            download_weights(SAFETY_URL, SAFETY_CACHE)

        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            SAFETY_CACHE,
            torch_dtype=DTYPE,
            local_files_only=True,
        )
        self.safety_checker.to(DEVICE)
        self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)

    def run_safety_checker(self, image):
        """Run the safety checker on a single image.

        Args:
            image: The generated image (a PIL image in `predict`).

        Returns:
            The ``(image, has_nsfw_concept)`` pair returned by the safety
            checker for the one-element batch.
        """
        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
            DEVICE
        )
        np_image = np.array(image)
        image, has_nsfw_concept = self.safety_checker(
            images=[np_image],
            clip_input=safety_checker_input.pixel_values.to(DTYPE),
        )
        return image, has_nsfw_concept

    def load_weights(self, sdxl_weights):
        """Build ``self.pipe`` for the requested SDXL base weights.

        Downloads the weights when their directory is missing, creates the
        InstantID pipeline around the IdentityNet ControlNet, loads the
        IP-Adapter and the LCM LoRA, and moves the pipeline to CUDA.

        Args:
            sdxl_weights: A key of ``SDXL_NAME_TO_PATHLIKE``.

        Raises:
            KeyError: If ``sdxl_weights`` is not a known key.
        """
        self.base_weights = sdxl_weights
        weights_info = SDXL_NAME_TO_PATHLIKE[self.base_weights]

        download_url = weights_info["url"]
        path_to_weights_dir = weights_info["path"]
        if not os.path.exists(path_to_weights_dir):
            download_weights(download_url, path_to_weights_dir)

        is_hugging_face_model = "slug" in weights_info
        path_to_weights_file = os.path.join(
            path_to_weights_dir,
            weights_info.get("file", ""),
        )

        print(f"[~] Loading new SDXL weights: {path_to_weights_file}")
        if is_hugging_face_model:
            self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
                weights_info["slug"],
                controlnet=[self.controlnet_identitynet],
                torch_dtype=DTYPE,
                cache_dir=CHECKPOINTS_CACHE,
                local_files_only=True,
                safety_checker=None,
                feature_extractor=None,
            )
            self.pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(
                self.pipe.scheduler.config
            )
        else:  # e.g. .safetensors, NOTE: This functionality is not being used right now
            # `from_single_file` builds a new pipeline, so the result has to be
            # stored; calling it on the old `self.pipe` and dropping the
            # return value would leave the previous weights in place.
            self.pipe = StableDiffusionXLInstantIDPipeline.from_single_file(
                path_to_weights_file,
                controlnet=self.controlnet_identitynet,
                torch_dtype=DTYPE,
                cache_dir=CHECKPOINTS_CACHE,
            )

        self.pipe.load_ip_adapter_instantid(self.face_adapter)
        self.setup_lcm_lora()
        self.pipe.cuda()

    def setup_lcm_lora(self):
        """Download (if missing) and load the LCM LoRA, leaving it disabled."""
        print("[~] Seting up LCM (just in case)")

        lcm_lora_key = "models--latent-consistency--lcm-lora-sdxl"
        lcm_lora_path = f"checkpoints/{lcm_lora_key}"
        if not os.path.exists(lcm_lora_path):
            download_weights(
                f"https://weights.replicate.delivery/default/InstantID/{lcm_lora_key}.tar",
                lcm_lora_path,
            )

        self.pipe.load_lora_weights(
            "latent-consistency/lcm-lora-sdxl",
            cache_dir=CHECKPOINTS_CACHE,
            local_files_only=True,
            weight_name="pytorch_lora_weights.safetensors",
        )
        # Kept off until a request enables LCM in `generate_image`.
        self.pipe.disable_lora()

    def setup_extra_controlnets(self):
        """Download (if missing) and load the pose, canny and depth ControlNets.

        Populates ``self.controlnet_map`` (name -> model) and
        ``self.controlnet_map_fn`` (name -> conditioning-image function).
        """
        print("[~] Seting up pose, canny, depth ControlNets")

        controlnet_repos = {
            "pose": "thibaud/controlnet-openpose-sdxl-1.0",
            "canny": "diffusers/controlnet-canny-sdxl-1.0",
            "depth": "diffusers/controlnet-depth-sdxl-1.0-small",
        }

        for repo_id in controlnet_repos.values():
            # Cache directories are named "models--<org>--<repo>".
            controlnet_key = "models--" + repo_id.replace("/", "--")
            controlnet_path = f"checkpoints/{controlnet_key}"
            if not os.path.exists(controlnet_path):
                download_weights(
                    f"https://weights.replicate.delivery/default/InstantID/{controlnet_key}.tar",
                    controlnet_path,
                )

        self.controlnet_map = {
            name: ControlNetModel.from_pretrained(
                repo_id,
                torch_dtype=DTYPE,
                cache_dir=CHECKPOINTS_CACHE,
                local_files_only=True,
            ).to(DEVICE)
            for name, repo_id in controlnet_repos.items()
        }
        self.controlnet_map_fn = {
            "pose": openpose,
            "canny": get_canny_image,
            "depth": get_depth_map,
        }

    @staticmethod
    def _largest_face(faces):
        """Return the detection whose bounding box covers the largest area."""

        def bbox_area(face):
            x1, y1, x2, y2 = face["bbox"]
            return (x2 - x1) * (y2 - y1)

        return max(faces, key=bbox_area)

    def generate_image(
        self,
        face_image_path,
        pose_image_path,
        prompt,
        negative_prompt,
        num_steps,
        identitynet_strength_ratio,
        adapter_strength_ratio,
        pose_strength,
        canny_strength,
        depth_strength,
        controlnet_selection,
        guidance_scale,
        seed,
        scheduler,
        enable_LCM,
        enhance_face_region,
        num_images_per_prompt,
    ):
        """Run the InstantID pipeline for one request.

        Args:
            face_image_path: Path/URL of the identity face image (required).
            pose_image_path: Optional path/URL of a pose reference image; when
                given, its face keypoints and size drive the generation.
            prompt: Positive prompt.
            negative_prompt: Negative prompt.
            num_steps: Number of inference steps.
            identitynet_strength_ratio: Conditioning scale of the IdentityNet.
            adapter_strength_ratio: IP-Adapter scale.
            pose_strength: Conditioning scale for the "pose" ControlNet.
            canny_strength: Conditioning scale for the "canny" ControlNet.
            depth_strength: Conditioning scale for the "depth" ControlNet.
            controlnet_selection: List drawn from "pose", "canny", "depth".
            guidance_scale: Classifier-free guidance scale.
            seed: Seed for the torch generator.
            scheduler: Scheduler name, optionally suffixed with "-Karras" and
                "-SDE" (ignored when ``enable_LCM`` is true).
            enable_LCM: Enable the LCM LoRA and ``LCMScheduler``.
            enhance_face_region: When true, pass a mask that is 255 inside
                the detected face bounding box as ``control_mask``.
            num_images_per_prompt: Number of images to generate.

        Returns:
            The ``.images`` attribute of the pipeline output.

        Raises:
            Exception: If no face image is given or no face is detected in
                the face image or the pose image.
        """
        if enable_LCM:
            self.pipe.enable_lora()
            self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
        else:
            self.pipe.disable_lora()
            # "<Class>[-Karras[-SDE]]": the first part names the diffusers
            # class, each further part switches on one extra option.
            scheduler_parts = scheduler.split("-")
            add_kwargs = {}
            if len(scheduler_parts) > 1:
                add_kwargs["use_karras_sigmas"] = True
            if len(scheduler_parts) > 2:
                add_kwargs["algorithm_type"] = "sde-dpmsolver++"
            scheduler_class = getattr(diffusers, scheduler_parts[0])
            self.pipe.scheduler = scheduler_class.from_config(
                self.pipe.scheduler.config,
                **add_kwargs,
            )

        if face_image_path is None:
            raise Exception(
                "Cannot find any input face `image`! Please upload the face `image`"
            )

        face_image = load_image(face_image_path)
        face_image = resize_img(face_image)
        face_image_cv2 = convert_from_image_to_cv2(face_image)
        height, width, _ = face_image_cv2.shape

        # Extract face features
        face_info = self.app.get(face_image_cv2)

        if len(face_info) == 0:
            raise Exception(
                "Face detector could not find a face in the `image`. Please use a different `image` as input."
            )

        # only use the maximum face
        face_info = self._largest_face(face_info)
        face_emb = face_info["embedding"]
        face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
        img_controlnet = face_image

        if pose_image_path is not None:
            pose_image = load_image(pose_image_path)
            pose_image = resize_img(pose_image, max_side=1024)
            img_controlnet = pose_image
            pose_image_cv2 = convert_from_image_to_cv2(pose_image)

            face_info = self.app.get(pose_image_cv2)

            if len(face_info) == 0:
                raise Exception(
                    "Face detector could not find a face in the `pose_image`. Please use a different `pose_image` as input."
                )

            # NOTE(review): the pose image uses the last detection rather than
            # the largest one -- confirm whether that is intended.
            face_info = face_info[-1]
            face_kps = draw_kps(pose_image, face_info["kps"])

            # The output takes the size of the pose keypoint image.
            width, height = face_kps.size

        if enhance_face_region:
            # Mask is 255 inside the bounding box of the face whose keypoints
            # are used (pose-image face when a pose image was given).
            control_mask = np.zeros([height, width, 3])
            x1, y1, x2, y2 = (int(coord) for coord in face_info["bbox"])
            control_mask[y1:y2, x1:x2] = 255
            control_mask = Image.fromarray(control_mask.astype(np.uint8))
        else:
            control_mask = None

        if len(controlnet_selection) > 0:
            controlnet_scales = {
                "pose": pose_strength,
                "canny": canny_strength,
                "depth": depth_strength,
            }
            # IdentityNet always comes first, followed by the selected extras;
            # scales and conditioning images follow the same order.
            self.pipe.controlnet = MultiControlNetModel(
                [self.controlnet_identitynet]
                + [self.controlnet_map[s] for s in controlnet_selection]
            )
            control_scales = [float(identitynet_strength_ratio)] + [
                controlnet_scales[s] for s in controlnet_selection
            ]
            control_images = [face_kps] + [
                self.controlnet_map_fn[s](img_controlnet).resize((width, height))
                for s in controlnet_selection
            ]
        else:
            self.pipe.controlnet = self.controlnet_identitynet
            control_scales = float(identitynet_strength_ratio)
            control_images = face_kps

        generator = torch.Generator(device=DEVICE).manual_seed(seed)

        print("Start inference...")
        print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")

        self.pipe.set_ip_adapter_scale(adapter_strength_ratio)
        images = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image_embeds=face_emb,
            image=control_images,
            control_mask=control_mask,
            controlnet_conditioning_scale=control_scales,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
            generator=generator,
            num_images_per_prompt=num_images_per_prompt,
        ).images

        return images

    def predict(
        self,
        image: Path = Input(
            description="Input face image",
        ),
        pose_image: Path = Input(
            description="(Optional) reference pose image",
            default=None,
        ),
        prompt: str = Input(
            description="Input prompt",
            default="a person",
        ),
        negative_prompt: str = Input(
            description="Input Negative Prompt",
            default="",
        ),
        sdxl_weights: str = Input(
            description="Pick which base weights you want to use",
            default="stable-diffusion-xl-base-1.0",
            choices=[
                "stable-diffusion-xl-base-1.0",
                "juggernaut-xl-v8",
                "afrodite-xl-v2",
                "albedobase-xl-20",
                "albedobase-xl-v13",
                "animagine-xl-30",
                "anime-art-diffusion-xl",
                "anime-illust-diffusion-xl",
                "dreamshaper-xl",
                "dynavision-xl-v0610",
                "guofeng4-xl",
                "nightvision-xl-0791",
                "omnigen-xl",
                "pony-diffusion-v6-xl",
                "protovision-xl-high-fidel",
                "RealVisXL_V3.0_Turbo",
                "RealVisXL_V4.0_Lightning",
            ],
        ),
        face_detection_input_width: int = Input(
            description="Width of the input image for face detection",
            default=640,
            ge=640,
            le=4096,
        ),
        face_detection_input_height: int = Input(
            description="Height of the input image for face detection",
            default=640,
            ge=640,
            le=4096,
        ),
        scheduler: str = Input(
            description="Scheduler",
            choices=[
                "DEISMultistepScheduler",
                "HeunDiscreteScheduler",
                "EulerDiscreteScheduler",
                "DPMSolverMultistepScheduler",
                "DPMSolverMultistepScheduler-Karras",
                "DPMSolverMultistepScheduler-Karras-SDE",
            ],
            default="EulerDiscreteScheduler",
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps",
            default=30,
            ge=1,
            le=500,
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance",
            default=7.5,
            ge=1,
            le=50,
        ),
        ip_adapter_scale: float = Input(
            description="Scale for image adapter strength (for detail)",  # adapter_strength_ratio
            default=0.8,
            ge=0,
            le=1.5,
        ),
        controlnet_conditioning_scale: float = Input(
            description="Scale for IdentityNet strength (for fidelity)",  # identitynet_strength_ratio
            default=0.8,
            ge=0,
            le=1.5,
        ),
        enable_pose_controlnet: bool = Input(
            description="Enable Openpose ControlNet, overrides strength if set to false",
            default=True,
        ),
        pose_strength: float = Input(
            description="Openpose ControlNet strength, effective only if `enable_pose_controlnet` is true",
            default=0.4,
            ge=0,
            le=1,
        ),
        enable_canny_controlnet: bool = Input(
            description="Enable Canny ControlNet, overrides strength if set to false",
            default=False,
        ),
        canny_strength: float = Input(
            description="Canny ControlNet strength, effective only if `enable_canny_controlnet` is true",
            default=0.3,
            ge=0,
            le=1,
        ),
        enable_depth_controlnet: bool = Input(
            description="Enable Depth ControlNet, overrides strength if set to false",
            default=False,
        ),
        depth_strength: float = Input(
            description="Depth ControlNet strength, effective only if `enable_depth_controlnet` is true",
            default=0.5,
            ge=0,
            le=1,
        ),
        enable_lcm: bool = Input(
            description="Enable Fast Inference with LCM (Latent Consistency Models) - speeds up inference steps, trade-off is the quality of the generated image. Performs better with close-up portrait face images",
            default=False,
        ),
        lcm_num_inference_steps: int = Input(
            description="Only used when `enable_lcm` is set to True, Number of denoising steps when using LCM",
            default=5,
            ge=1,
            le=10,
        ),
        lcm_guidance_scale: float = Input(
            description="Only used when `enable_lcm` is set to True, Scale for classifier-free guidance when using LCM",
            default=1.5,
            ge=1,
            le=20,
        ),
        enhance_nonface_region: bool = Input(
            description="Enhance non-face region", default=True
        ),
        output_format: str = Input(
            description="Format of the output images",
            choices=["webp", "jpg", "png"],
            default="webp",
        ),
        output_quality: int = Input(
            description="Quality of the output images, from 0 to 100. 100 is best quality, 0 is lowest quality.",
            default=80,
            ge=0,
            le=100,
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed",
            default=None,
        ),
        num_outputs: int = Input(
            description="Number of images to output",
            default=1,
            ge=1,
            le=8,
        ),
        disable_safety_checker: bool = Input(
            description="Disable safety checker for generated images",
            default=False,
        ),
    ) -> List[Path]:
        """Run a single prediction on the model"""

        # If no seed is provided, generate a random seed
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        # Load the weights if they are different from the base weights
        if sdxl_weights != self.base_weights:
            self.load_weights(sdxl_weights)

        # Re-prepare the face analyser when the requested detection size
        # differs from the one it is currently configured with.
        if (
            self.face_detection_input_width != face_detection_input_width
            or self.face_detection_input_height != face_detection_input_height
        ):
            print(f"[!] Resizing output to {face_detection_input_width}x{face_detection_input_height}")
            self.face_detection_input_width = face_detection_input_width
            self.face_detection_input_height = face_detection_input_height
            self.app.prepare(
                ctx_id=0,
                det_size=(
                    self.face_detection_input_width,
                    self.face_detection_input_height,
                ),
            )

        # Set up ControlNet selection and their respective strength values (if any)
        controlnet_selection = []
        if pose_strength > 0 and enable_pose_controlnet:
            controlnet_selection.append("pose")
        if canny_strength > 0 and enable_canny_controlnet:
            controlnet_selection.append("canny")
        if depth_strength > 0 and enable_depth_controlnet:
            controlnet_selection.append("depth")

        # Switch to LCM inference steps and guidance scale if LCM is enabled
        if enable_lcm:
            num_inference_steps = lcm_num_inference_steps
            guidance_scale = lcm_guidance_scale

        # Generate
        images = self.generate_image(
            face_image_path=str(image),
            pose_image_path=str(pose_image) if pose_image else None,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_steps=num_inference_steps,
            identitynet_strength_ratio=controlnet_conditioning_scale,
            adapter_strength_ratio=ip_adapter_scale,
            pose_strength=pose_strength,
            canny_strength=canny_strength,
            depth_strength=depth_strength,
            controlnet_selection=controlnet_selection,
            scheduler=scheduler,
            guidance_scale=guidance_scale,
            seed=seed,
            enable_LCM=enable_lcm,
            enhance_face_region=enhance_nonface_region,
            num_images_per_prompt=num_outputs,
        )

        # "jpg" is saved with the "jpeg" extension and format name.
        extension = output_format.lower()
        extension = "jpeg" if extension == "jpg" else extension
        is_png = output_format == "png"

        # Save the generated images and check for NSFW content
        output_paths = []
        for i, output_image in enumerate(images):
            if not disable_safety_checker:
                _, has_nsfw_content_list = self.run_safety_checker(output_image)
                has_nsfw_content = any(has_nsfw_content_list)
                print(f"NSFW content detected: {has_nsfw_content}")
                if has_nsfw_content:
                    raise Exception(
                        "NSFW content detected. Try running it again, or try a different prompt."
                    )

            output_path = f"/tmp/out_{i}.{extension}"

            print(f"[~] Saving to {output_path}...")
            print(f"[~] Output format: {extension.upper()}")
            if not is_png:
                print(f"[~] Output quality: {output_quality}")

            # Quality/optimize are only passed for the non-PNG formats.
            save_params = {"format": extension.upper()}
            if not is_png:
                save_params["quality"] = output_quality
                save_params["optimize"] = True

            output_image.save(output_path, **save_params)
            output_paths.append(Path(output_path))

        return output_paths