# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
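"""Inference wrapper for ACE++ (ace_plus) image editing, built on diffusers'
FluxFillPipeline."""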
import os
import random
from collections import OrderedDict

import torch
from diffusers import FluxFillPipeline
from scepter.modules.utils.config import Config
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS
from scepter.modules.utils.logger import get_logger
from transformers import T5TokenizerFast
from .utils import ACEPlusImageProcessor


class ACEPlusDiffuserInference:
    def __init__(self, logger=None):
        if logger is None:
            logger = get_logger(name='ace_plus')
        self.logger = logger
        self.input = {}

    def load_default(self, cfg):
        # Cache the default input/output parameters from the config;
        # mapping-style entries contribute their 'DEFAULT' value, while
        # scalar entries pass through unchanged.
        if cfg is not None:
            self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
            self.input = {
                k.lower():
                dict(v).get('DEFAULT', None)
                if isinstance(v, (dict, OrderedDict, Config)) else v
                for k, v in cfg.INPUT.items()
            }
            self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}

    def init_from_cfg(self, cfg):
        self.max_seq_len = cfg.get("MAX_SEQ_LEN", 4096)
        self.image_processor = ACEPlusImageProcessor(max_seq_len=self.max_seq_len)

        # Resolve the pretrained weights to a local folder and build the
        # diffusers FluxFillPipeline on the current device.
        local_folder = FS.get_dir_to_local_dir(cfg.MODEL.PRETRAINED_MODEL)
        self.pipe = FluxFillPipeline.from_pretrained(
            local_folder, torch_dtype=torch.bfloat16).to(we.device_id)

        # Re-create the secondary T5 tokenizer with "{image}" registered as a
        # special token so it is kept intact during tokenization.
        tokenizer_2 = T5TokenizerFast.from_pretrained(
            os.path.join(local_folder, "tokenizer_2"),
            additional_special_tokens=["{image}"])
        self.pipe.tokenizer_2 = tokenizer_2
        self.load_default(cfg.DEFAULT_PARAS)

    def prepare_input(self,
                      image,
                      mask,
                      batch_size=1,
                      dtype=torch.bfloat16,
                      num_images_per_prompt=1,
                      height=512,
                      width=512,
                      generator=None):
        # Encode the image and mask into packed latents using the pipeline's
        # own helper, mirroring FluxFillPipeline's internal mask preparation.
        num_channels_latents = self.pipe.vae.config.latent_channels
        mask, masked_image_latents = self.pipe.prepare_mask_latents(
            mask.unsqueeze(0),
            image.unsqueeze(0).to(we.device_id, dtype=dtype),
            batch_size,
            num_channels_latents,
            num_images_per_prompt,
            height,
            width,
            dtype,
            we.device_id,
            generator,
        )
        # Concatenate packed image latents and mask along the channel axis, as
        # FluxFillPipeline expects for its `masked_image_latents` argument.
        masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
        return masked_image_latents

    @torch.no_grad()
    def __call__(self,
                 reference_image=None,
                 edit_image=None,
                 edit_mask=None,
                 prompt='',
                 task=None,
                 output_height=1024,
                 output_width=1024,
                 sampler='flow_euler',
                 sample_steps=28,
                 guide_scale=50,
                 lora_path=None,
                 seed=-1,
                 tar_index=0,
                 align=0,
                 repainting_scale=0,
                 **kwargs):
        if isinstance(prompt, str):
            prompt = [prompt]
        seed = seed if seed >= 0 else random.randint(0, 2 ** 32 - 1)
        # preprocess returns (edit_image, edit_mask, change_image,
        # content_image, out_h, out_w, slice_w); out_h/out_w/slice_w locate
        # the region that postprocess crops back out of the generated canvas.
        image, mask, _, _, out_h, out_w, slice_w = self.image_processor.preprocess(
            reference_image,
            edit_image,
            edit_mask,
            width=output_width,
            height=output_height,
            repainting_scale=repainting_scale)
        h, w = image.shape[1:]
        generator = torch.Generator("cpu").manual_seed(seed)
        masked_image_latents = self.prepare_input(image,
                                                  mask,
                                                  batch_size=len(prompt),
                                                  height=h,
                                                  width=w,
                                                  generator=generator)

        # Optionally apply task-specific LoRA weights for this call only;
        # they are unloaded again after generation.
        if lora_path is not None:
            with FS.get_from(lora_path) as local_path:
                self.pipe.load_lora_weights(local_path)

        image = self.pipe(
            prompt=prompt,
            masked_image_latents=masked_image_latents,
            height=h,
            width=w,
            guidance_scale=guide_scale,
            num_inference_steps=sample_steps,
            max_sequence_length=512,
            generator=generator
        ).images[0]
        if lora_path is not None:
            self.pipe.unload_lora_weights()
        # Crop the generated canvas back to the final output image.
        return self.image_processor.postprocess(image, slice_w, out_w, out_h), seed


if __name__ == '__main__':
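    # Minimal usage sketch, not upstream documentation: the YAML path below is
    # a hypothetical placeholder, and loading it via Config(cfg_file=...)
    # assumes scepter's usual config convention.
    cfg = Config(cfg_file='path/to/ace_plus_diffusers.yaml')  # hypothetical path
    infer = ACEPlusDiffuserInference()
    infer.init_from_cfg(cfg)
    # edit_image/edit_mask take whatever ACEPlusImageProcessor.preprocess
    # accepts; the call returns (output_image, seed).
    # output_image, seed = infer(edit_image=edit_image, edit_mask=edit_mask,
    #                            prompt='repaint the masked region')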