# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import random
from collections import OrderedDict

import numpy as np
import torch
from PIL import Image
from scepter.modules.model.registry import MODELS
from scepter.modules.utils.config import Config
from scepter.modules.utils.distribute import we
from .registry import BaseInference, INFERENCES
from .utils import ACEPlusImageProcessor

@INFERENCES.register_class()
class ACEInference(BaseInference):
    '''
        Inference wrapper that reuses the LDM pipeline code for ACE image
        editing.
    '''
    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger)
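        # Build the diffusion model pipeline from the config and move it to
        # this process's device in eval mode.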
        self.pipe = MODELS.build(cfg.MODEL, logger=self.logger).eval().to(we.device_id)
        self.image_processor = ACEPlusImageProcessor(max_seq_len=cfg.MAX_SEQ_LEN)
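        # SAMPLE_ARGS values may be plain scalars or Config/dict nodes carrying
        # a DEFAULT field; flatten them into a lower-cased arg -> default map.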
        self.input = {
            k.lower(): dict(v).get('DEFAULT', None)
            if isinstance(v, (dict, OrderedDict, Config)) else v
            for k, v in cfg.SAMPLE_ARGS.items()
        }
        self.dtype = getattr(torch, cfg.get("DTYPE", "bfloat16"))

    @torch.no_grad()
    def __call__(self,
                 reference_image=None,
                 edit_image=None,
                 edit_mask=None,
                 prompt='',
                 edit_type=None,
                 output_height=1024,
                 output_width=1024,
                 sampler='flow_euler',
                 sample_steps=28,
                 guide_scale=50,
                 lora_path=None,
                 seed=-1,
                 repainting_scale=0,
                 use_change=False,
                 keep_pixels=False,
                 keep_pixels_rate=0.8,
                 **kwargs):
        # Convert the caller's inputs into the format the LDM pipeline expects.
        if isinstance(prompt, str):
            prompt = [prompt]
        seed = seed if seed >= 0 else random.randint(0, 2 ** 24 - 1)
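        # Pack the reference image, edit image, and mask into one model-ready
        # canvas; slice_w marks where the reference strip ends so postprocess()
        # can crop it back off the generated result.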
        image, mask, change_image, content_image, out_h, out_w, slice_w = self.image_processor.preprocess(
            reference_image, edit_image, edit_mask,
            height=output_height, width=output_width,
            repainting_scale=repainting_scale,
            keep_pixels=keep_pixels,
            keep_pixels_rate=keep_pixels_rate,
            use_change=use_change)
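        # Move everything onto the inference device; change_image stays a
        # [None] sentinel when no change image was produced.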
        change_image = [None] if change_image is None else [change_image.to(we.device_id)]
        image, mask = [image.to(we.device_id)], [mask.to(we.device_id)]

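        # The pipeline expects per-sample lists of images/masks, so wrap each
        # input one level deeper (a single edit per sample, hence edit_id [[0]]).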
        (src_image_list, src_mask_list, modify_image_list,
         edit_id, prompt) = [image], [mask], [change_image], [[0]], [prompt]

        with torch.amp.autocast(enabled=True, dtype=self.dtype, device_type='cuda'):
            out_image = self.pipe(
                src_image_list=src_image_list,
                modify_image_list=modify_image_list,
                src_mask_list=src_mask_list,
                edit_id=edit_id,
                image=image,
                image_mask=mask,
                prompt=prompt,
                sampler=sampler,
                sample_steps=sample_steps,
                seed=seed,
                guide_scale=guide_scale,
                show_process=True,
            )
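        # Convert each generated tensor (C, H, W, values in [0, 1]) to an HWC
        # numpy array, then to a uint8 PIL image.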
        imgs = [x_i['reconstruct_image'].float().permute(1, 2, 0).cpu().numpy()
                for x_i in out_image]
        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
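        # Also return visualization copies of the preprocessed inputs, mapping
        # them from [-1, 1] back to [0, 255].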
        edit_image = Image.fromarray(
            (torch.clamp(image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
        # Guard against the [None] sentinel: only build a PIL image when a
        # change image actually exists.
        change_image = None if change_image[0] is None else Image.fromarray(
            (torch.clamp(change_image[0] / 2 + 0.5, min=0.0, max=1.0) * 255).float().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
        mask = Image.fromarray((mask[0] * 255).squeeze(0).cpu().numpy().astype(np.uint8))
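        # postprocess() crops the padded/concatenated canvas back to the
        # requested output resolution.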
        return self.image_processor.postprocess(imgs[0], slice_w, out_w, out_h), edit_image, change_image, mask, seed
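

# A minimal usage sketch (hypothetical config path and input files; the cfg
# must provide the MODEL, MAX_SEQ_LEN, and SAMPLE_ARGS keys read in __init__):
#
#   from PIL import Image
#   from scepter.modules.utils.config import Config
#
#   cfg = Config(load=True, cfg_file='config/ace_plus_infer.yaml')
#   inference = ACEInference(cfg)
#   out, edit_vis, change_vis, mask_vis, seed = inference(
#       edit_image=Image.open('input.png').convert('RGB'),
#       edit_mask=Image.open('mask.png').convert('L'),
#       prompt='repaint the masked region as a wooden door',
#       seed=42)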