import spaces
import os
import json
import time
import copy
import torch
from diffusers import AutoPipelineForText2Image, StableDiffusionPipeline,DiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL, AutoencoderTiny, UNet2DConditionModel
from huggingface_hub import hf_hub_download, snapshot_download
from pathlib import Path
from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSDEScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import utils
import base64
import hashlib
import ipown
import jwt
import glob
import traceback
from insightface.app import FaceAnalysis
import cv2
import gradio as gr
#from onediffx import compile_pipe, save_pipe, load_pipe
HF_TOKEN = os.getenv('HF_TOKEN')
VAR_PUBLIC_KEY = os.getenv('PUBLIC_KEY')
DATASET_ID = 'nsfwalex/checkpoint_n_lora'
class AuthHelper:
def load_public_key_from_file(self):
public_key_bytes = VAR_PUBLIC_KEY.encode('utf-8') # Convert to bytes if it's a string
public_key = serialization.load_pem_public_key(
public_key_bytes,
backend=default_backend()
)
return public_key
def __init__(self):
self.public_key = self.load_public_key_from_file()
    # check authkey:
    # 1. decode the token with the public key
    # 2. check the timestamp
    # 3. check that the current IP, host, referer, and user agent match the values in the JWT
def decode_jwt(self, token, algorithms=["RS256"]):
"""
Decode and verify a JWT using a public key.
        Verification uses the instance's public key (loaded in __init__).
:param token: The JWT string to decode.
:param algorithms: List of acceptable algorithms (default is ["RS256"]).
:return: The decoded JWT payload if verification is successful.
:raises: Exception if verification fails.
"""
try:
# Decode the JWT
decoded_payload = jwt.decode(
token,
self.public_key,
algorithms=algorithms,
options={"verify_signature": True} # Explicitly enable signature verification
)
return decoded_payload
except Exception as e:
print("Invalid token:", e)
raise
    def check_auth(self, request, token):
        # Without a request object there is nothing to verify against.
        if not request:
            return True
        # Extract query parameters from the request
        params = dict(request.query_params)
        if params.get("_skip_token_passkey", "") == "nsfwaisio_125687":
            return True
# Gather request-specific information
sip = request.client.host
shost = request.headers.get("Host", "")
sreferer = request.headers.get("Referer", "")
suseragent = request.headers.get("User-Agent", "")
print(sip, shost, sreferer, suseragent)
# Decode the JWT token
jwt_data = self.decode_jwt(token)
jwt_auth = jwt_data.get("auth", "")
if not jwt_auth:
raise Exception("Missing auth field in token")
# Create the MD5 hash of ip + host + referer + useragent
auth_string = f"{sip}{shost}{sreferer}{suseragent}"
calculated_md5 = hashlib.md5(auth_string.encode('utf-8')).hexdigest()
print(f"Calculated MD5: {calculated_md5}, JWT Auth: {jwt_auth}")
# Compare the calculated hash with the `auth` field from the JWT
if calculated_md5 == jwt_auth:
return True
raise Exception("Invalid authentication")
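# A minimal sketch (assumed, not used anywhere in this module) of how the token issuer is
# expected to build the "auth" claim that AuthHelper.check_auth() verifies; the private-key
# handling below is an illustrative assumption.
def _example_build_auth_claim(ip, host, referer, useragent, private_key_pem):
    # Hash the request fingerprint exactly as check_auth() recomputes it.
    auth = hashlib.md5(f"{ip}{host}{referer}{useragent}".encode("utf-8")).hexdigest()
    # Sign with the RSA private key matching PUBLIC_KEY so decode_jwt() can verify it.
    return jwt.encode({"auth": auth}, private_key_pem, algorithm="RS256")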
class InferenceManager:
def __init__(self, config_path="config.json", ext_model_pathes={}):
cfg = {}
with open(config_path, "r", encoding="utf-8") as f:
cfg = json.load(f)
self.cfg = cfg
        self.ext_model_pathes = ext_model_pathes or {}
lora_options_path = cfg.get("loras", "")
self.model_version = cfg["model_version"]
self.lora_load_options = self.load_json(lora_options_path) # Load LoRA load options
self.lora_models = self.load_index_file("index.json") # Load index.json
self.preloaded_loras = [] # Array to store preloaded LoRAs with name and weights
self.ip_adapter_faceid_pipeline = None
self.base_model_pipeline = self.load_base_model() # Load the base model
self.preload_loras() # Preload LoRAs based on options
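    # The config keys read throughout this class suggest a *.model.json layout roughly like
    # the sketch below. Field names are taken from the cfg.get()/cfg[] calls in this file;
    # the concrete values are illustrative assumptions only:
    #
    #   {
    #     "model_id": "some-org/some-checkpoint-repo",
    #     "model_version": "xl",             # "1.5", "xl" or "pony"
    #     "loras": "lora_options.json",      # optional LoRA preload options (see preload_loras)
    #     "vae": "tae",                      # "", "tae", or a VAE repo id
    #     "clip_skip": 1,
    #     "sampler": "DPM2 a",
    #     "load_ip_adapter_faceid": false,
    #     "prompt": "{prompt}",
    #     "negative_prompt": "",
    #     "inference_steps": 30,
    #     "guidance_scale": 7,
    #     "width": 512,
    #     "height": 512,
    #     "cfg_disabling_rate": 0.75
    #   }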
def load_json(self, filepath):
"""Load JSON file into a dictionary."""
if os.path.exists(filepath):
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
return {}
def load_index_file(self, index_file):
"""Download index.json from Hugging Face and return the file path."""
index_path = download_from_hf(index_file)
if index_path:
with open(index_path, "r", encoding="utf-8") as f:
return json.load(f)
return {}
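    # index.json is expected to carry a "lora" list; the fields accessed elsewhere in this
    # class ("name", "path", and "attr.base_model") suggest entries shaped roughly like the
    # sketch below (the concrete names are illustrative assumptions):
    #
    #   {
    #     "lora": [
    #       {"name": "my_lora", "path": "loras/my_lora.safetensors",
    #        "attr": {"base_model": ["xl", "pony"]}}
    #     ]
    #   }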
@spaces.GPU(duration=40)
def compile_onediff(self):
self.base_model_pipeline.to("cuda")
pipe = self.base_model_pipeline
        # Load the previously compiled pipe (requires the onediffx import above to be enabled)
        load_pipe(pipe, dir="cached_pipe")
print("Start oneflow compiling...")
start_compile = time.time()
pipe = compile_pipe(pipe)
# run once to trigger compilation
image = pipe(
prompt="street style, detailed, raw photo, woman, face, shot on CineStill 800T",
height=512,
width=512,
num_inference_steps=10,
output_type="pil",
).images
        image[0].save("test_image.png")
compile_time = time.time() - start_compile
#self.base_model_pipeline.to("cpu")
# save the compiled pipe
save_pipe(pipe, dir="cached_pipe")
self.base_model_pipeline = pipe
print(f"OneDiff compile in {compile_time}s")
def load_base_model(self):
"""Load base model and return the pipeline."""
start = time.time()
cfg = self.cfg
model_version = self.model_version
ckpt_dir = snapshot_download(repo_id=cfg["model_id"], local_files_only=False)
if model_version == "1.5":
vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
pipe = StableDiffusionPipeline.from_pretrained(ckpt_dir, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True)
else:
use_vae = cfg.get("vae", "")
if not use_vae:
vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
elif use_vae == "tae":
vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
else:
vae = AutoencoderTiny.from_pretrained(use_vae, torch_dtype=torch.bfloat16)
print(ckpt_dir)
pipe = DiffusionPipeline.from_pretrained(
ckpt_dir,
vae=vae,
#unet=unet,
torch_dtype=torch.bfloat16,
use_safetensors=True,
#variant="fp16",
custom_pipeline = "lpw_stable_diffusion_xl",
)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        clip_skip = cfg.get("clip_skip", 1)
        # Adjust clip skip for XL (assumed not relevant for SD 1.5). Note: editing the config
        # after the text encoder has been instantiated does not remove layers from the
        # already-built module; kept as in the original logic.
        pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)
load_time = round(time.time() - start, 2)
print(f"Base model loaded in {load_time}s")
if cfg.get("load_ip_adapter_faceid", False):
if model_version in ("pony", "xl"):
ip_ckpt = self.ext_model_pathes.get("ip-adapter-faceid-sdxl", "")
if ip_ckpt:
print(f"loading ip adapter model for {model_name}")
self.ip_adapter_faceid_pipeline = ipown.IPAdapterFaceIDXL(pipe, ip_ckpt, 'cuda')
else:
print("ip-adapter-faceid-sdxl not found, skip")
return pipe
def preload_loras(self):
"""Preload all LoRAs marked as 'preload=True' and store for later use."""
for lora_name, lora_info in self.lora_load_options.items():
try:
start = time.time()
# Find the corresponding LoRA in index.json
lora_index_info = next((l for l in self.lora_models['lora'] if l['name'] == lora_name), None)
if not lora_index_info:
raise ValueError(f"LoRA {lora_name} not found in index.json.")
                # Skip LoRAs that are not marked for preload or don't target the current base model
                if self.model_version not in lora_info['base_model'] or not lora_info.get('preload', False):
                    print(f"Skipping {lora_name}: not marked for preload or not compatible with the current model version.")
                    continue
# Load LoRA weights from the specified path
weight_path = download_from_hf(lora_index_info['path'], local_dir=None)
if not weight_path:
raise ValueError(f"Failed to download LoRA weights for {lora_name}")
load_time = round(time.time() - start, 2)
print(f"Downloaded {lora_name} in {load_time}s")
self.base_model_pipeline.load_lora_weights(
weight_path,
weight_name=lora_index_info["path"],
adapter_name=lora_name
)
# Store the preloaded LoRA name and weight for merging later
if lora_info.get("preload", False):
self.preloaded_loras.append({
"name": lora_name,
"weight": lora_info.get("weight", 1.0)
})
load_time = round(time.time() - start, 2)
print(f"Preloaded LoRA {lora_name} with weight {lora_info.get('weight', 1.0)} in {load_time}s.")
except Exception as e:
print(f"Lora {lora_name} not loaded, skipping... {e}")
def build_pipeline_with_lora(self, lora_list, sampler=None, new_pipeline=False):
"""Build the pipeline with specific LoRAs, loading any that are not preloaded."""
# Deep copy the base pipeline
start = time.time()
if new_pipeline:
temp_pipeline = copy.deepcopy(self.base_model_pipeline)
else:
temp_pipeline = self.base_model_pipeline
        copy_time = round(time.time() - start, 2)
        print(f"Pipeline prepared in {copy_time}s")
# Track LoRAs to be loaded dynamically
dynamic_loras = []
# Check if any LoRAs in lora_list need to be loaded dynamically
for lora_name in lora_list:
if not any(l['name'] == lora_name for l in self.preloaded_loras):
lora_info = next((l for l in self.lora_models['lora'] if l['name'] == lora_name), None)
if lora_info and self.model_version in lora_info["attr"].get("base_model", []):
dynamic_loras.append({
"name": lora_name,
"filename": lora_info["path"],
"scale": 1.0 # Assuming default weight as 1.0 for dynamic LoRAs
})
# Fuse preloaded and dynamic LoRAs
all_loras = [{"name": x["name"], "scale": x["weight"], "preloaded": True} for x in self.preloaded_loras] + dynamic_loras
        set_lora_weights(temp_pipeline, all_loras, fuse=False)
build_time = round(time.time() - start, 2)
print(f"Pipeline built with LoRAs in {build_time}s.")
if not sampler:
sampler = self.cfg.get("sampler", "DPM2 a")
        # Define samplers. Note: `use_2m` and `use_2s` are not actual diffusers scheduler
        # arguments; from_config() ignores unknown kwargs, so those entries behave like the
        # scheduler defaults.
        samplers = {
            "Euler a": EulerAncestralDiscreteScheduler.from_config(temp_pipeline.scheduler.config),
            "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(temp_pipeline.scheduler.config, use_karras_sigmas=True),
            "DPM2 a": DPMSolverMultistepScheduler.from_config(temp_pipeline.scheduler.config),
            "DPM++ SDE": DPMSolverSDEScheduler.from_config(temp_pipeline.scheduler.config),
            "DPM++ 2M SDE": DPMSolverSDEScheduler.from_config(temp_pipeline.scheduler.config, use_2m=True),
            "DPM++ 2S a": DPMSolverMultistepScheduler.from_config(temp_pipeline.scheduler.config, use_2s=True)
        }
        # Set the scheduler based on the selected sampler; keep the current scheduler if the name is unknown
        temp_pipeline.scheduler = samplers.get(sampler, temp_pipeline.scheduler)
        # The caller is responsible for moving the pipeline to the GPU (see generate()).
return temp_pipeline
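    # A minimal usage sketch (the config path and LoRA name below are illustrative assumptions):
    #
    #   manager = InferenceManager("example.model.json")
    #   pipe = manager.build_pipeline_with_lora(["my_lora"], sampler="Euler a", new_pipeline=True)
    #   pipe.to("cuda")
    #   image = pipe(prompt="a portrait photo", num_inference_steps=30).images[0]
    #   manager.release(pipe)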
def release(self, temp_pipeline):
"""Release the deepcopied pipeline to recycle memory."""
del temp_pipeline
torch.cuda.empty_cache()
print("Memory released and cache cleared.")
class ModelManager:
def __init__(self, model_directory):
"""
Initialize the ModelManager by scanning all `.model.json` files in the given directory.
:param model_directory: The directory to scan for model config files (e.g., "/path/to/models").
"""
print("downloading models...")
self.ext_model_pathes = {
"ip-adapter-faceid-sdxl": hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid_sdxl.bin", repo_type="model")
}
self.models = {}
self.ext_models = {}
self.model_directory = model_directory
self.load_models()
#not enabled at the moment
def load_instant_x(self):
#load all models
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
os.makedirs("./models",exist_ok=True)
download_from_hf("models/antelopev2/1k3d68.onnx",local_dir="./models")
download_from_hf("models/antelopev2/2d106det.onnx",local_dir="./models")
download_from_hf("models/antelopev2/genderage.onnx",local_dir="./models")
download_from_hf("models/antelopev2/glintr100.onnx",local_dir="./models")
download_from_hf("models/antelopev2/scrfd_10g_bnkps.onnx",local_dir="./models")
# prepare 'antelopev2' under ./models
app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
# prepare models under ./checkpoints
        face_adapter = './checkpoints/ip-adapter.bin'
        controlnet_path = './checkpoints/ControlNetModel'
def load_models(self):
"""
Scan the model directory for `.model.json` files and initialize InferenceManager instances for each one.
:param model_directory: Directory to scan for `.model.json` files.
"""
model_files = glob.glob(os.path.join(self.model_directory, "*.model.json"))
if not model_files:
print(f"No model configuration files found in {self.model_directory}")
return
for file_path in model_files:
model_name = self.get_model_name_from_url(file_path).split(".")[0]
print(f"Initializing model: {model_name} from {file_path}")
try:
# Initialize InferenceManager for each model
self.models[model_name] = InferenceManager(config_path=file_path, ext_model_pathes=self.ext_model_pathes)
except Exception as e:
print(traceback.format_exc())
print(f"Failed to initialize model {model_name} from {file_path}: {e}")
def get_model_name_from_url(self, url):
"""
Extract the model name from the config file path (filename without extension).
:param url: The file path of the configuration file.
:return: The model name (file name without extension).
"""
filename = os.path.basename(url)
model_name, _ = os.path.splitext(filename)
return model_name
def get_model_pipeline(self, model_id, lora_list, sampler=None, new_pipeline=False):
"""
Build the pipeline with specific LoRAs for a model.
:param model_id: The model ID (the model name extracted from the config URL).
:param lora_list: List of LoRAs to be applied to the model pipeline.
:param sampler: The sampler to be used for the pipeline.
:param new_pipeline: Flag to indicate whether to create a new pipeline or reuse the existing one.
:return: The built pipeline with LoRAs applied.
"""
model = self.models.get(model_id)
if not model:
print(f"Model {model_id} not found.")
return None
try:
print(f"Building pipeline with LoRAs for model {model_id}...")
return model.build_pipeline_with_lora(lora_list, sampler, new_pipeline)
except Exception as e:
print(traceback.format_exc())
print(f"Failed to build pipeline for model {model_id}: {e}")
return None
def release_model(self, model_id):
"""
Release resources and clear memory for a specific model.
:param model_id: The model ID (the model name extracted from the config URL).
"""
model = self.models.get(model_id)
if not model:
print(f"Model {model_id} not found.")
return
try:
print(f"Releasing model {model_id}...")
model.release(model.base_model_pipeline)
except Exception as e:
print(f"Failed to release model {model_id}: {e}")
@spaces.GPU(duration=40)
def generate_with_faceid(self, model_id, inference_params, progress=gr.Progress(track_tqdm=True)):
model = self.models.get(model_id)
if not model:
raise Exception(f"invalid model_id {model_id}")
if not model.ip_adapter_faceid_pipeline:
raise Exception(f"model does not support ip adapter")
pipe = model.ip_adapter_faceid_pipeline
cfg = model.cfg
p = inference_params.get("prompt")
negative_prompt = inference_params.get("negative_prompt", cfg.get("negative_prompt", ""))
steps = inference_params.get("steps", cfg.get("inference_steps", 30))
guidance_scale = inference_params.get("guidance_scale", cfg.get("guidance_scale", 7))
width = inference_params.get("width", cfg.get("width", 512))
height = inference_params.get("height", cfg.get("height", 512))
images = inference_params.get("images", [])
likeness_strength = inference_params.get("likeness_strength", 0.4)
face_strength = inference_params.get("face_strength", 0.1)
sampler = inference_params.get("sampler", cfg.get("sampler", ""))
lora_list = inference_params.get("loras", [])
if not images:
raise Exception(f"face images not provided")
start = time.time()
pipe.to("cuda")
print("loading face analysis...")
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(512, 512))
faceid_all_embeds = []
for image in images:
face = cv2.imread(image)
faces = app.get(face)
faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
faceid_all_embeds.append(faceid_embed)
average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
print("start inference...")
style_selection = ""
use_negative_prompt = True
randomize_seed = True
seed = seed or int(randomize_seed_fn(seed, randomize_seed))
p = remove_child_related_content(p)
prompt_str = cfg.get("prompt", "{prompt}").replace("{prompt}", p)
generator = torch.Generator(pipe.device).manual_seed(seed)
print(f"generate: p={p}, np={np}, steps={steps}, guidance_scale={guidance_scale}, size={width},{height}, seed={seed}")
images = pipe(
prompt=prompt_str,
negative_prompt=negative_prompt,
faceid_embeds=average_embedding,
scale=likeness_strength,
width=width,
height=height,
guidance_scale=face_strength,
num_inference_steps=steps,
generator=generator,
num_images_per_prompt=1,
output_type="pil",
#callback_on_step_end=callback_dynamic_cfg,
#callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
).images
cost = round(time.time() - start, 2)
print(f"inference done in {cost}s")
        # `save_image` is expected to be provided by the hosting app and return an (image, path) pair.
        images = [save_image(img) for img in images]
image_paths = [i[1] for i in images]
print(prompt_str, image_paths)
return [i[0] for i in images]
@spaces.GPU(duration=40)
def generate(self, model_id, inference_params, progress=gr.Progress(track_tqdm=True)):
        # Dynamic CFG: once `cfg_disabling_rate` of the timesteps have run, drop the
        # unconditional half of the batched embeddings and set guidance to 0 so the remaining
        # steps run without classifier-free guidance.
        def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
cfg_disabling_at = cfg.get('cfg_disabling_rate', 0.75)
if step_index == int(pipe.num_timesteps * cfg_disabling_at):
callback_kwargs['prompt_embeds'] = callback_kwargs['prompt_embeds'].chunk(2)[-1]
callback_kwargs['add_text_embeds'] = callback_kwargs['add_text_embeds'].chunk(2)[-1]
callback_kwargs['add_time_ids'] = callback_kwargs['add_time_ids'].chunk(2)[-1]
pipe._guidance_scale = 0.0
return callback_kwargs
model = self.models.get(model_id)
if not model:
raise Exception(f"invalid model_id {model_id}")
cfg = model.cfg
p = inference_params.get("prompt")
negative_prompt = inference_params.get("negative_prompt", cfg.get("negative_prompt", ""))
inference_steps = inference_params.get("steps", cfg.get("inference_steps", 30))
guidance_scale = inference_params.get("guidance_scale", cfg.get("guidance_scale", 7))
width = inference_params.get("width", cfg.get("width", 512))
height = inference_params.get("height", cfg.get("height", 512))
sampler = inference_params.get("sampler", cfg.get("sampler", ""))
lora_list = inference_params.get("loras", [])
        pipe = model.build_pipeline_with_lora(lora_list, sampler)
start = time.time()
pipe.to("cuda")
print("start inference...")
style_selection = ""
use_negative_prompt = True
randomize_seed = True
seed = seed or int(randomize_seed_fn(seed, randomize_seed))
guidance_scale = guidance_scale or cfg.get("guidance_scale", 7.5)
p = remove_child_related_content(p)
prompt_str = cfg.get("prompt", "{prompt}").replace("{prompt}", p)
generator = torch.Generator(pipe.device).manual_seed(seed)
print(f"generate: p={p}, np={np}, steps={steps}, guidance_scale={guidance_scale}, size={width},{height}, seed={seed}")
images = pipe(
prompt=prompt_str,
negative_prompt=negative_prompt,
width=width,
height=height,
guidance_scale=guidance_scale,
num_inference_steps=inference_steps,
generator=generator,
num_images_per_prompt=1,
output_type="pil",
callback_on_step_end=callback_dynamic_cfg,
callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
).images
cost = round(time.time() - start, 2)
print(f"inference done in {cost}s")
images = [save_image(img) for img in images]
image_paths = [i[1] for i in images]
print(prompt_str, image_paths)
return [i[0] for i in images]
# Hugging Face file download function - returns only file path
def download_from_hf(filename, local_dir=None, repo_id=DATASET_ID, repo_type="dataset"):
    try:
        file_path = hf_hub_download(
            filename=filename,
            repo_id=repo_id,
            repo_type=repo_type,
            revision="main",
            local_dir=local_dir,
            local_files_only=False,  # Attempt to load from cache if available
        )
        return file_path  # Return file path only
except Exception as e:
print(f"Failed to load {filename} from Hugging Face: {str(e)}")
return None
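# A minimal usage sketch (the LoRA file name below is an illustrative assumption; index.json
# is the same file load_index_file() fetches):
#
#   index_path = download_from_hf("index.json")
#   lora_path = download_from_hf("loras/my_lora.safetensors", local_dir="./loras")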
# Function to load and fuse LoRAs
def set_lora_weights(pipe, lorajson: list[dict], fuse=False):
try:
if not lorajson or not isinstance(lorajson, list):
return
a_list = []
w_list = []
for d in lorajson:
if not d or not isinstance(d, dict) or not d["name"] or d["name"] == "None":
continue
k = d["name"]
if not d.get("preloaded", False):
start = time.time()
weight_path = download_from_hf(d['filename'], local_dir=None)
if weight_path:
pipe.load_lora_weights(weight_path, weight_name=d['filename'], adapter_name=k)
load_time = round(time.time() - start, 2)
print(f"LoRA {k} loaded in {load_time}s.")
a_list.append(k)
w_list.append(d["scale"])
if not a_list:
return
start = time.time()
pipe.set_adapters(a_list, adapter_weights=w_list)
if fuse:
pipe.fuse_lora(adapter_names=a_list, lora_scale=1.0)
fuse_time = round(time.time() - start, 2)
print(f"LoRAs fused in {fuse_time}s.")
except Exception as e:
print(f"External LoRA Error: {e}")
raise Exception(f"External LoRA Error: {e}") from e