---
license: mit
---
|
|
|
|
|
|
|
Usage of InstantID-XS:
|
|
|
# 1. Download models
|
|
|
```bash
# InstantID-XS
huggingface-cli download --resume-download RED-AIGC/InstantID-XS --local-dir ./checkpoints/RED-AIGC/InstantID-XS

# VAE: madebyollin/sdxl-vae-fp16-fix
huggingface-cli download --resume-download madebyollin/sdxl-vae-fp16-fix --local-dir ./checkpoints/madebyollin/sdxl-vae-fp16-fix

# Base model: RealVisXL V4.0
huggingface-cli download --resume-download frankjoshua/realvisxlV40_v40Bakedvae --local-dir ./checkpoints/frankjoshua/realvisxlV40_v40Bakedvae
```
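After these commands finish, `./checkpoints` should contain one subfolder per repo, matching the paths the loading code in step 2 expects (the three `.bin` files are the ones loaded there):

```
checkpoints/
├── RED-AIGC/InstantID-XS/                  # image_proj.bin, controlnetxs.bin, cross_attn.bin
├── madebyollin/sdxl-vae-fp16-fix/
└── frankjoshua/realvisxlV40_v40Bakedvae/
```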
|
|
|
|
|
|
|
# 2. Build the pipeline
|
|
|
|
|
|
|
Note: In ControlNetXS, the ControlNet branch receives the same `encoder_hidden_states` input as the UNet by default, namely the prompt embeddings. We decouple the two inputs: the UNet keeps the prompt embeddings as its `encoder_hidden_states`, while the ControlNet branch receives the face embeddings instead.
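As a shape-level sketch of what each branch consumes: SDXL prompt embeddings and the insightface identity embedding have the shapes shown below, while the number of face tokens produced by `image_proj` is an illustrative assumption.

```python
import torch

batch = 1
# SDXL prompt embeddings: 77 tokens x 2048 dims (both text encoders concatenated)
prompt_embeds = torch.randn(batch, 77, 2048)
# antelopev2 identity embedding: a single 512-dim vector per face
face_emb = torch.randn(batch, 512)
# image_proj.bin maps the 512-dim embedding into a short sequence of
# cross-attention tokens (the token count here is an assumption)
face_tokens = torch.randn(batch, 4, 2048)

# UNet branch:       encoder_hidden_states = prompt_embeds  (text conditioning)
# ControlNet branch: encoder_hidden_states = face_tokens    (identity conditioning)
```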
|
|
|
```python
import os

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler

from controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
# The pipeline class ships with this repo; the module path below is an
# assumption, adjust it to match where the pipeline file lives in your checkout.
from pipeline_stable_diffusion_xl_instantid_xs import StableDiffusionXLInstantIDXSPipeline

device = 'cuda'
weight_dtype = torch.float16

base_model = './checkpoints/frankjoshua/realvisxlV40_v40Bakedvae'
vae_path = './checkpoints/madebyollin/sdxl-vae-fp16-fix'
ckpt = './checkpoints/RED-AIGC/InstantID-XS'

image_proj_path = os.path.join(ckpt, "image_proj.bin")
cnxs_path = os.path.join(ckpt, "controlnetxs.bin")
cross_attn_path = os.path.join(ckpt, "cross_attn.bin")

# Get ControlNetXS:
unet = UNet2DConditionModel.from_pretrained(base_model, subfolder="unet").to(device, dtype=weight_dtype)
controlnet = ControlNetXSAdapter.from_unet(unet, size_ratio=0.125, learn_time_embedding=True)

# The checkpoint stores the control-branch weights with a 'ctrl_' prefix (from
# the combined UNetControlNetXSModel). Strip it to match the adapter's own key
# names (the 'ctrl_to_base' connections keep theirs), and rename 'up_blocks'
# to 'up_connections' to fit the adapter's layout.
state_dict = torch.load(cnxs_path, map_location="cpu", weights_only=True)
ctrl_state_dict = {}
for key, value in state_dict.items():
    if 'ctrl_' in key and 'ctrl_to_base' not in key:
        key = key.replace('ctrl_', '')
    if 'up_blocks' in key:
        key = key.replace('up_blocks', 'up_connections')
    ctrl_state_dict[key] = value
controlnet.load_state_dict(ctrl_state_dict, strict=True)
controlnet.to(device, dtype=weight_dtype)
ControlNetXS = UNetControlNetXSModel.from_unet(unet, controlnet).to(device, dtype=weight_dtype)

# Get pipeline
vae = AutoencoderKL.from_pretrained(vae_path)

pipe = StableDiffusionXLInstantIDXSPipeline.from_pretrained(
    base_model,
    vae=vae,
    unet=ControlNetXS,
    controlnet=None,
    torch_dtype=weight_dtype,
)

pipe.cuda(device=device, dtype=weight_dtype, use_xformers=True)
pipe.load_ip_adapter(image_proj_path, cross_attn_path)

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Match learn_time_embedding=True used when the adapter was created above.
pipe.unet.config.ctrl_learn_time_embedding = True
pipe = pipe.to(device)
```
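`size_ratio=0.125` sizes the control branch at one eighth of the base UNet's channel widths, which is what keeps InstantID-XS lightweight. A quick sanity check on the resulting sizes (run after the snippet above, which defines `unet` and `controlnet`):

```python
# Compare parameter counts of the base UNet and the slim control branch.
unet_params = sum(p.numel() for p in unet.parameters()) / 1e6
ctrl_params = sum(p.numel() for p in controlnet.parameters()) / 1e6
print(f"base UNet: {unet_params:.0f}M params | control branch: {ctrl_params:.0f}M params")
```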
|
|
|
|
|
|
|
# 3. Inference
|
|
|
```python
import cv2
import numpy as np
import torch
from PIL import Image
from insightface.app import FaceAnalysis

# antelopev2 must be placed under ./models/antelopev2 (insightface resolves
# model packs relative to `root`).
app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))

img_path = './image.jpg'
image = cv2.imread(img_path)
image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
image = resize_img(image)  # utility function from this repo

# Detect faces and keep the one with the largest bounding box.
face_infos = app.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
face_info = sorted(face_infos, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
face_emb = torch.from_numpy(face_info.normed_embedding)  # 512-dim identity embedding
face_kps = draw_kps_pil(image, face_info['kps'])         # keypoint condition image (repo utility)

prompt = 'a woman, (looking at the viewer), portrait, daily wear, 8K texture, realistic, symmetrical hyperdetailed texture, masterpiece, enhanced details, (eye highlight:2), perfect composition, natural lighting, best quality, authentic, natural posture'
n_prompt = '(worst quality:2), (low quality:2), (normal quality:2), lowres, bad anatomy, bad hands, normal quality, long neck, hunchback, narrow shoulder, wall, (blurry), vague, indistinct, (shiny face:2), (buffing:2), (face highlight:2), pale skin'

seed = 0
image = pipe(
    prompt=prompt,
    negative_prompt=n_prompt,
    image=face_kps,
    face_emb=face_emb,
    num_images_per_prompt=1,
    num_inference_steps=20,
    generator=torch.Generator(device=device).manual_seed(seed),
    ip_adapter_scale=0.8,
    guidance_scale=4.0,
    controlnet_conditioning_scale=0.8,
).images[0]
```
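The returned `image` is a standard PIL image and can be saved directly:

```python
image.save('./result.jpg')
```

`resize_img` and `draw_kps_pil` used above are utility functions from this repository. For reference, a minimal stand-in for `resize_img` might look like the following (an assumption about its behavior, not the repo's exact implementation):

```python
from PIL import Image

def resize_img(image: Image.Image, max_side: int = 1280, min_side: int = 1024) -> Image.Image:
    # Scale so the short side reaches min_side (capped so the long side stays
    # within max_side), then snap both sides to multiples of 64 for SDXL.
    w, h = image.size
    ratio = min_side / min(w, h)
    if max(w, h) * ratio > max_side:
        ratio = max_side / max(w, h)
    w, h = round(w * ratio), round(h * ratio)
    return image.resize((w - w % 64, h - h % 64), Image.BILINEAR)
```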
|
|
|
|