Spaces:
Running
on
Zero
Running
on
Zero
# -*- coding: utf-8 -*- | |
# Copyright (c) Alibaba, Inc. and its affiliates. | |
import random | |
from collections import OrderedDict | |
import torch, numpy as np | |
from PIL import Image | |
from scepter.modules.model.registry import MODELS | |
from scepter.modules.utils.config import Config | |
from scepter.modules.utils.distribute import we | |
from .registry import BaseInference, INFERENCES | |
from .utils import ACEPlusImageProcessor | |
class ACEInference(BaseInference):
    '''
    Inference wrapper that reuses the ldm code.

    Builds the model described by ``cfg.MODEL``, moves it to the current
    device in eval mode, and exposes ``__call__`` to run one editing /
    generation request end to end (preprocess -> pipeline -> postprocess).
    '''

    def __init__(self, cfg, logger=None):
        '''
        Args:
            cfg: Config object; this constructor reads ``MODEL``,
                ``MAX_SEQ_LEN``, ``SAMPLE_ARGS`` and the optional ``DTYPE``
                (name of a ``torch`` dtype attribute, default "bfloat16").
            logger: Optional logger forwarded to the base class.
        '''
        super().__init__(cfg, logger)
        self.pipe = MODELS.build(cfg.MODEL, logger=self.logger).eval().to(we.device_id)
        self.image_processor = ACEPlusImageProcessor(max_seq_len=cfg.MAX_SEQ_LEN)
        # Lower-case every sample-arg name; dict-like entries contribute
        # their 'DEFAULT' value (None when absent), plain values are kept.
        self.input = {
            k.lower(): dict(v).get('DEFAULT', None)
            if isinstance(v, (dict, OrderedDict, Config)) else v
            for k, v in cfg.SAMPLE_ARGS.items()
        }
        self.dtype = getattr(torch, cfg.get("DTYPE", "bfloat16"))

    @staticmethod
    def _normalized_tensor_to_pil(tensor):
        '''
        Convert an image tensor to a PIL image.

        Values are mapped with ``x / 2 + 0.5``, clamped to [0, 1] and scaled
        to 0-255 before the channel axis is moved last.
        NOTE(review): presumably the input is (C, H, W) in [-1, 1] -- confirm
        against ACEPlusImageProcessor.preprocess.
        '''
        pixels = torch.clamp(tensor / 2 + 0.5, min=0.0, max=1.0) * 255
        return Image.fromarray(
            pixels.float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))

    def __call__(self,
                 reference_image=None,
                 edit_image=None,
                 edit_mask=None,
                 prompt='',
                 edit_type=None,
                 output_height=1024,
                 output_width=1024,
                 sampler='flow_euler',
                 sample_steps=28,
                 guide_scale=50,
                 lora_path=None,
                 seed=-1,
                 repainting_scale=0,
                 use_change=False,
                 keep_pixels=False,
                 keep_pixels_rate=0.8,
                 **kwargs):
        '''
        Run one inference request.

        Args:
            reference_image, edit_image, edit_mask: Inputs forwarded to the
                image processor's ``preprocess``.
            prompt: A string (wrapped into a one-element list) or a list.
            edit_type: Accepted for interface compatibility; not used here.
            output_height, output_width: Forwarded to ``preprocess`` as
                ``height`` / ``width``.
            sampler: Sampler name forwarded to the pipeline.
            sample_steps, guide_scale: Forwarded to the pipeline.
            lora_path: Accepted for interface compatibility; not used here.
            seed: Seed for the pipeline. ``None`` or a negative value draws a
                random seed in [0, 2**24 - 1].
            repainting_scale, use_change, keep_pixels, keep_pixels_rate:
                Forwarded to ``preprocess``.
            **kwargs: Ignored.

        Returns:
            A 5-tuple ``(result, edit_image, change_image, mask, seed)``:
            the postprocessed first pipeline output, the preprocessed edit
            image and change image as PIL images (``change_image`` is
            ``None`` when preprocessing produced none), the mask as a PIL
            image, and the seed actually used.
        '''
        # convert the input info to the input of ldm.
        if isinstance(prompt, str):
            prompt = [prompt]
        # A missing seed is treated like a negative one instead of failing
        # on the comparison.
        if seed is None or seed < 0:
            seed = random.randint(0, 2 ** 24 - 1)

        (image, mask, change_image, _, out_h, out_w,
         slice_w) = self.image_processor.preprocess(
            reference_image, edit_image, edit_mask,
            height=output_height, width=output_width,
            repainting_scale=repainting_scale,
            keep_pixels=keep_pixels,
            keep_pixels_rate=keep_pixels_rate,
            use_change=use_change)

        change_image = [None] if change_image is None else [change_image.to(we.device_id)]
        image, mask = [image.to(we.device_id)], [mask.to(we.device_id)]
        # The pipeline takes one extra level of list nesting on these inputs.
        (src_image_list, src_mask_list, modify_image_list,
         edit_id, prompt) = [image], [mask], [change_image], [[0]], [prompt]

        with torch.amp.autocast(enabled=True, dtype=self.dtype, device_type='cuda'):
            out_image = self.pipe(
                src_image_list=src_image_list,
                modify_image_list=modify_image_list,
                src_mask_list=src_mask_list,
                edit_id=edit_id,
                image=image,
                image_mask=mask,
                prompt=prompt,
                # Honour the caller's choice; the default is 'flow_euler'.
                sampler=sampler,
                sample_steps=sample_steps,
                seed=seed,
                guide_scale=guide_scale,
                show_process=True,
            )

        imgs = [
            x_i['reconstruct_image'].float().permute(1, 2, 0).cpu().numpy()
            for x_i in out_image
        ]
        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]

        edit_image = self._normalized_tensor_to_pil(image[0])
        # change_image[0] is None when preprocessing returned no change
        # image; report that as None rather than doing arithmetic on it.
        if change_image[0] is None:
            change_image = None
        else:
            change_image = self._normalized_tensor_to_pil(change_image[0])
        mask = Image.fromarray(
            (mask[0] * 255).squeeze(0).cpu().numpy().astype(np.uint8))

        result = self.image_processor.postprocess(imgs[0], slice_w, out_w, out_h)
        return result, edit_image, change_image, mask, seed