chaojiemao committed (verified)
Commit 01e514a · 1 Parent(s): aac0117

Upload 7 files

config/ace_plus_diffusers_infer.yaml ADDED
@@ -0,0 +1,25 @@
+NAME: ace_plus_diffuser_infer
+IS_DEFAULT: True
+USE_DYNAMIC_MODEL: False
+INFERENCE_TYPE: ACE_DIFFUSER_PLUS
+DEFAULT_PARAS:
+  PARAS:
+  #
+  INPUT:
+    INPUT_IMAGE:
+    INPUT_MASK:
+    TASK:
+    PROMPT: ""
+    OUTPUT_HEIGHT: 1024
+    OUTPUT_WIDTH: 1024
+    SAMPLER: flow_euler
+    SAMPLE_STEPS: 28
+    GUIDE_SCALE: 50
+    SEED: 42
+    MAX_SEQ_LENGTH: 4096
+  OUTPUT:
+    LATENT:
+    IMAGES:
+    SEED:
+MODEL:
+  PRETRAINED_MODEL: ${FLUX_FILL_PATH}
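
A minimal loading sketch (not part of the commit), assuming scepter's Config accepts a YAML path via cfg_file and substitutes ${FLUX_FILL_PATH} from the environment; the model id below is purely illustrative:

    import os
    from scepter.modules.utils.config import Config

    # Illustrative value only; point this at your FLUX.1-Fill checkpoint.
    os.environ["FLUX_FILL_PATH"] = "black-forest-labs/FLUX.1-Fill-dev"

    cfg = Config(cfg_file="config/ace_plus_diffusers_infer.yaml")
    print(cfg.NAME)                    # ace_plus_diffuser_infer
    print(cfg.MODEL.PRETRAINED_MODEL)  # the resolved FLUX_FILL_PATH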
examples/__init__.py ADDED
File without changes
examples/examples.py ADDED
@@ -0,0 +1,81 @@
+all_examples = [
+    {
+        "input_image": None,
+        "input_mask": None,
+        "input_reference_image": "assets/samples/portrait/human_1.jpg",
+        "save_path": "examples/outputs/portrait_human_1.jpg",
+        "instruction": "Maintain the facial features, A girl is wearing a neat police uniform and sporting a badge. She is smiling with a friendly and confident demeanor. The background is blurred, featuring a cartoon logo.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 4194866942,
+        "repainting_scale": 1.0,
+        "task_type": "portrait",
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": None,
+        "input_mask": None,
+        "input_reference_image": "assets/samples/subject/subject_1.jpg",
+        "save_path": "examples/outputs/subject_subject_1.jpg",
+        "instruction": "Display the logo in a minimalist style printed in white on a matte black ceramic coffee mug, alongside a steaming cup of coffee on a cozy cafe table.",
+        "output_h": 1024,
+        "output_w": 1024,
+        "seed": 2935362780,
+        "repainting_scale": 1.0,
+        "task_type": "subject",
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "assets/samples/local/local_1.webp",
+        "input_mask": "assets/samples/local/local_1_m.webp",
+        "input_reference_image": None,
+        "save_path": "examples/outputs/local_local_1.jpg",
+        "instruction": "By referencing the mask, restore a partial image from the doodle {image} that aligns with the textual explanation: \"1 white old owl\".",
+        "output_h": -1,
+        "output_w": -1,
+        "seed": 1159797084,
+        "repainting_scale": 0.5,
+        "task_type": "local_editing",
+        "edit_type": "contour_repainting"
+    },
+    {
+        "input_image": "assets/samples/application/photo_editing/1_1_edit.png",
+        "input_mask": "assets/samples/application/photo_editing/1_1_m.png",
+        "input_reference_image": "assets/samples/application/photo_editing/1_ref.png",
+        "save_path": "examples/outputs/photo_editing_1.jpg",
+        "instruction": "The item is put on the ground.",
+        "output_h": -1,
+        "output_w": -1,
+        "seed": 2072028954,
+        "repainting_scale": 1.0,
+        "task_type": "subject",
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "assets/samples/application/logo_paste/1_1_edit.png",
+        "input_mask": "assets/samples/application/logo_paste/1_1_m.png",
+        "input_reference_image": "assets/samples/application/logo_paste/1_ref.png",
+        "save_path": "examples/outputs/logo_paste_1.jpg",
+        "instruction": "The logo is printed on the headphones.",
+        "output_h": -1,
+        "output_w": -1,
+        "seed": 934582264,
+        "repainting_scale": 1.0,
+        "task_type": "subject",
+        "edit_type": "repainting"
+    },
+    {
+        "input_image": "assets/samples/application/movie_poster/1_1_edit.png",
+        "input_mask": "assets/samples/application/movie_poster/1_1_m.png",
+        "input_reference_image": "assets/samples/application/movie_poster/1_ref.png",
+        "save_path": "examples/outputs/movie_poster_1.jpg",
+        "instruction": "The man is facing the camera and is smiling.",
+        "output_h": -1,
+        "output_w": -1,
+        "seed": 988183236,
+        "repainting_scale": 1.0,
+        "task_type": "portrait",
+        "edit_type": "repainting"
+    }
+]
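
A hedged driver sketch for these examples: the keyword mapping follows the signature of ACEPlusDiffuserInference.__call__ (added below in inference/ace_plus_diffusers.py), but the PIL loading helper is hypothetical, and the per-task annotator preprocessing implied by edit_type (see models/model_zoo.yaml) is not applied here:

    import os
    from PIL import Image
    from scepter.modules.utils.config import Config
    from examples.examples import all_examples
    from inference.ace_plus_diffusers import ACEPlusDiffuserInference

    infer = ACEPlusDiffuserInference()
    infer.init_from_cfg(Config(cfg_file="config/ace_plus_diffusers_infer.yaml"))

    def load(path, mode):
        # Hypothetical helper: examples use None when an input is absent.
        return Image.open(path).convert(mode) if path else None

    for ex in all_examples:
        os.makedirs(os.path.dirname(ex["save_path"]), exist_ok=True)
        image, used_seed = infer(
            reference_image=load(ex["input_reference_image"], "RGB"),
            edit_image=load(ex["input_image"], "RGB"),
            edit_mask=load(ex["input_mask"], "L"),
            prompt=ex["instruction"],
            output_height=ex["output_h"],
            output_width=ex["output_w"],
            seed=ex["seed"],
            repainting_scale=ex["repainting_scale"])
        image.save(ex["save_path"])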
inference/__init__.py ADDED
File without changes
inference/ace_plus_diffusers.py ADDED
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import random
+from collections import OrderedDict
+
+import torch
+from diffusers import FluxFillPipeline
+from scepter.modules.utils.config import Config
+from scepter.modules.utils.distribute import we
+from scepter.modules.utils.file_system import FS
+from scepter.modules.utils.logger import get_logger
+from transformers import T5TokenizerFast
+
+from .utils import ACEPlusImageProcessor
+
+
+class ACEPlusDiffuserInference():
+    def __init__(self, logger=None):
+        if logger is None:
+            logger = get_logger(name='ace_plus')
+        self.logger = logger
+        self.input = {}
+
+    def load_default(self, cfg):
+        # Lower-case the config keys and pull per-field DEFAULT values where present.
+        if cfg is not None:
+            self.input_cfg = {k.lower(): v for k, v in cfg.INPUT.items()}
+            self.input = {
+                k.lower(): dict(v).get('DEFAULT', None)
+                if isinstance(v, (dict, OrderedDict, Config)) else v
+                for k, v in cfg.INPUT.items()
+            }
+            self.output = {k.lower(): v for k, v in cfg.OUTPUT.items()}
+
+    def init_from_cfg(self, cfg):
+        self.max_seq_len = cfg.get("MAX_SEQ_LEN", 4096)
+        self.image_processor = ACEPlusImageProcessor(max_seq_len=self.max_seq_len)
+
+        local_folder = FS.get_dir_to_local_dir(cfg.MODEL.PRETRAINED_MODEL)
+        self.pipe = FluxFillPipeline.from_pretrained(
+            local_folder, torch_dtype=torch.bfloat16).to("cuda")
+
+        # Register "{image}" as a special token so the reference-image
+        # placeholder in instructions survives T5 tokenization intact.
+        tokenizer_2 = T5TokenizerFast.from_pretrained(
+            os.path.join(local_folder, "tokenizer_2"),
+            additional_special_tokens=["{image}"])
+        self.pipe.tokenizer_2 = tokenizer_2
+        self.load_default(cfg.DEFAULT_PARAS)
+
+    def prepare_input(self,
+                      image,
+                      mask,
+                      batch_size=1,
+                      dtype=torch.bfloat16,
+                      num_images_per_prompt=1,
+                      height=512,
+                      width=512,
+                      generator=None):
+        # Reuse the Fill pipeline's own packing, then append the packed mask
+        # to the masked-image latents along the channel dimension.
+        num_channels_latents = self.pipe.vae.config.latent_channels
+        mask, masked_image_latents = self.pipe.prepare_mask_latents(
+            mask.unsqueeze(0),
+            image.unsqueeze(0).to(we.device_id, dtype=dtype),
+            batch_size,
+            num_channels_latents,
+            num_images_per_prompt,
+            height,
+            width,
+            dtype,
+            we.device_id,
+            generator,
+        )
+        masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
+        return masked_image_latents
+
+    @torch.no_grad()
+    def __call__(self,
+                 reference_image=None,
+                 edit_image=None,
+                 edit_mask=None,
+                 prompt='',
+                 task=None,
+                 output_height=1024,
+                 output_width=1024,
+                 sampler='flow_euler',
+                 sample_steps=28,
+                 guide_scale=50,
+                 lora_path=None,
+                 seed=-1,
+                 tar_index=0,
+                 align=0,
+                 repainting_scale=0,
+                 **kwargs):
+        if isinstance(prompt, str):
+            prompt = [prompt]
+        seed = seed if seed >= 0 else random.randint(0, 2 ** 32 - 1)
+        image, mask, out_h, out_w, slice_w = self.image_processor.preprocess(
+            reference_image, edit_image, edit_mask,
+            repainting_scale=repainting_scale)
+        h, w = image.shape[1:]
+        generator = torch.Generator("cpu").manual_seed(seed)
+        masked_image_latents = self.prepare_input(
+            image, mask, batch_size=len(prompt), height=h, width=w,
+            generator=generator)
+
+        if lora_path is not None:
+            with FS.get_from(lora_path) as local_path:
+                self.pipe.load_lora_weights(local_path)
+
+        image = self.pipe(
+            prompt=prompt,
+            masked_image_latents=masked_image_latents,
+            height=h,
+            width=w,
+            guidance_scale=guide_scale,
+            num_inference_steps=sample_steps,
+            max_sequence_length=512,
+            generator=generator
+        ).images[0]
+        return self.image_processor.postprocess(image, slice_w, out_w, out_h), seed
+
+
+if __name__ == '__main__':
+    pass
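
Putting the files together, a hedged end-to-end sketch: it assumes the config above parses via scepter's Config, that ${FLUX_FILL_PATH} resolves from the environment, and that a CUDA device is available (the pipeline is moved to "cuda" unconditionally). PORTRAIT_MODEL_PATH mirrors models/model_zoo.yaml; its value is deployment-specific:

    import os
    from scepter.modules.utils.config import Config
    from inference.ace_plus_diffusers import ACEPlusDiffuserInference

    cfg = Config(cfg_file="config/ace_plus_diffusers_infer.yaml")
    infer = ACEPlusDiffuserInference()
    infer.init_from_cfg(cfg)

    # With no edit image, preprocess() synthesizes a blank, fully masked
    # canvas, so this is reference-free generation steered by the prompt.
    image, used_seed = infer(
        prompt="a portrait of a smiling police officer",
        lora_path=os.environ.get("PORTRAIT_MODEL_PATH"),  # optional task LoRA
        seed=42)
    image.save("portrait_out.jpg")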
inference/utils.py ADDED
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import math
+
+import numpy as np
+import torch
+import torchvision.transforms as T
+from PIL import Image
+from scepter.modules.annotator.registry import ANNOTATORS
+from scepter.modules.utils.config import Config
+
+
+def edit_preprocess(processor, device, edit_image, edit_mask):
+    if edit_image is None or processor is None:
+        return edit_image
+    processor = Config(cfg_dict=processor, load=False)
+    processor = ANNOTATORS.build(processor).to(device)
+    new_edit_image = processor(np.asarray(edit_image))
+    processor = processor.to("cpu")
+    del processor
+    new_edit_image = Image.fromarray(new_edit_image)
+    # Only the masked region takes the annotator output; the rest of the
+    # image is left untouched.
+    return Image.composite(new_edit_image, edit_image, edit_mask)
+
+
+class ACEPlusImageProcessor():
+    def __init__(self, max_aspect_ratio=4, d=16, max_seq_len=1024):
+        self.max_aspect_ratio = max_aspect_ratio
+        self.d = d
+        self.max_seq_len = max_seq_len
+        self.transforms = T.Compose([
+            T.ToTensor(),
+            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+        ])
+
+    def image_check(self, image):
+        if image is None:
+            return image
+        # Center-crop images that exceed the maximum aspect ratio, then
+        # normalize to [-1, 1].
+        W, H = image.size
+        if H / W > self.max_aspect_ratio:
+            image = T.CenterCrop([int(self.max_aspect_ratio * W), W])(image)
+        elif W / H > self.max_aspect_ratio:
+            image = T.CenterCrop([H, int(self.max_aspect_ratio * H)])(image)
+        return self.transforms(image)
+
+    def preprocess(self,
+                   reference_image=None,
+                   edit_image=None,
+                   edit_mask=None,
+                   height=1024,
+                   width=1024,
+                   repainting_scale=1.0):
+        reference_image = self.image_check(reference_image)
+        edit_image = self.image_check(edit_image)
+        # For pure reference generation, synthesize a blank canvas that is
+        # fully masked for repainting.
+        if edit_image is None:
+            edit_image = torch.zeros([3, height, width])
+            edit_mask = torch.ones([1, height, width])
+        else:
+            edit_mask = np.asarray(edit_mask)
+            edit_mask = np.where(edit_mask > 128, 1, 0)
+            edit_mask = edit_mask.astype(
+                np.float32) if np.any(edit_mask) else np.ones_like(
+                    edit_mask).astype(np.float32)
+            edit_mask = torch.tensor(edit_mask).unsqueeze(0)
+
+        # repainting_scale = 1 blanks out the masked region entirely;
+        # 0 keeps the original pixels there as guidance.
+        edit_image = edit_image * (1 - edit_mask * repainting_scale)
+
+        out_h, out_w = edit_image.shape[-2:]
+
+        assert edit_mask is not None
+        if reference_image is not None:
+            # Scale the reference to the edit image's height, then place the
+            # two side by side; the reference half is excluded from the mask.
+            _, H, W = reference_image.shape
+            _, eH, eW = edit_image.shape
+            scale = eH / H
+            tH, tW = eH, int(W * scale)
+            reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(reference_image)
+            edit_image = torch.cat([reference_image, edit_image], dim=-1)
+            edit_mask = torch.cat([torch.zeros([1, reference_image.shape[1], reference_image.shape[2]]), edit_mask], dim=-1)
+            slice_w = reference_image.shape[-1]
+        else:
+            slice_w = 0
+
+        # Downscale so the token count fits the sequence budget, keeping both
+        # sides divisible by self.d.
+        H, W = edit_image.shape[-2:]
+        scale = min(1.0, math.sqrt(self.max_seq_len * 2 / ((H / self.d) * (W / self.d))))
+        rH = int(H * scale) // self.d * self.d
+        rW = int(W * scale) // self.d * self.d
+        slice_w = int(slice_w * scale) // self.d * self.d
+
+        edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(edit_image)
+        edit_mask = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT, antialias=True)(edit_mask)
+
+        return edit_image, edit_mask, out_h, out_w, slice_w
+
+    def postprocess(self, image, slice_w, out_w, out_h):
+        # Crop away the reference half (plus a small safety margin) and
+        # restore the original edit resolution.
+        w, h = image.size
+        if slice_w > 0:
+            output_image = image.crop((slice_w + 20, 0, w, h))
+            output_image = output_image.resize((out_w, out_h))
+        else:
+            output_image = image
+        return output_image
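
A hedged smoke test for the processor (requires scepter, torchvision, and Pillow; the 512x512 sizes are arbitrary). With a reference image present, the returned tensor holds reference and edit canvases side by side, so the width doubles and slice_w marks the seam:

    from PIL import Image
    from inference.utils import ACEPlusImageProcessor

    proc = ACEPlusImageProcessor(max_seq_len=4096)
    ref = Image.new("RGB", (512, 512), "white")
    edit = Image.new("RGB", (512, 512), "gray")
    mask = Image.new("L", (512, 512), 255)  # fully masked -> full repainting

    image, m, out_h, out_w, slice_w = proc.preprocess(
        ref, edit, mask, repainting_scale=1.0)
    print(image.shape, m.shape, slice_w)
    # torch.Size([3, 512, 1024]) torch.Size([1, 512, 1024]) 512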
models/model_zoo.yaml ADDED
@@ -0,0 +1,28 @@
+MODEL:
+  PORTRAIT:
+    MODEL_PATH: ${PORTRAIT_MODEL_PATH}
+  SUBJECT:
+    MODEL_PATH: ${SUBJECT_MODEL_PATH}
+  LOCAL_EDITING:
+    MODEL_PATH: ${LOCAL_MODEL_PATH}
+    REPAINTING_SCALE: 0.5
+    PREPROCESSOR:
+      - NAME: CannyAnnotator
+        TYPE: canny_repainting
+        LOW_THRESHOLD: 100
+        HIGH_THRESHOLD: 200
+      - NAME: ColorAnnotator
+        TYPE: mosaic_repainting
+        RATIO: 64
+      - NAME: InfoDrawContourAnnotator
+        TYPE: contour_repainting
+        INPUT_NC: 3
+        OUTPUT_NC: 1
+        N_RESIDUAL_BLOCKS: 3
+        SIGMOID: True
+        PRETRAINED_MODEL: "ms://iic/scepter_annotator@annotator/ckpts/informative_drawing_contour_style.pth"
+      - NAME: MidasDetector
+        PRETRAINED_MODEL: "ms://iic/scepter_annotator@annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
+        TYPE: depth_repainting
+      - NAME: GrayAnnotator
+        TYPE: recolorizing
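
A hedged sketch of how one PREPROCESSOR entry might feed edit_preprocess() from inference/utils.py: the dict literal mirrors the CannyAnnotator entry above, and the sample paths are reused from examples/examples.py for illustration:

    from PIL import Image
    from inference.utils import edit_preprocess

    canny_cfg = {"NAME": "CannyAnnotator",
                 "TYPE": "canny_repainting",
                 "LOW_THRESHOLD": 100,
                 "HIGH_THRESHOLD": 200}
    edit_image = Image.open("assets/samples/local/local_1.webp").convert("RGB")
    edit_mask = Image.open("assets/samples/local/local_1_m.webp").convert("L")
    # The annotator redraws only the masked region; elsewhere the photo is kept.
    conditioned = edit_preprocess(canny_cfg, "cuda", edit_image, edit_mask)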