Bingsu commited on
Commit
cd267d9
1 Parent(s): 39a2a92

Upload files: v0.2.1

Browse files
Files changed (7) hide show
  1. asdff/__init__.py +10 -0
  2. asdff/__version__.py +1 -0
  3. asdff/base.py +174 -0
  4. asdff/sd.py +51 -0
  5. asdff/utils.py +70 -0
  6. asdff/yolo.py +80 -0
  7. pipeline.py +1 -0
asdff/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .__version__ import __version__
2
+ from .sd import AdCnPipeline, AdPipeline
3
+ from .yolo import yolo_detector
4
+
5
+ __all__ = [
6
+ "AdPipeline",
7
+ "AdCnPipeline",
8
+ "yolo_detector",
9
+ "__version__",
10
+ ]
asdff/__version__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.2.1"
asdff/base.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from abc import ABC, abstractmethod
5
+ from typing import Any, Callable, Iterable, List, Mapping, Optional
6
+
7
+ from diffusers.utils import logging
8
+ from PIL import Image
9
+
10
+ from asdff.utils import (
11
+ ADOutput,
12
+ bbox_padding,
13
+ composite,
14
+ mask_dilate,
15
+ mask_gaussian_blur,
16
+ )
17
+ from asdff.yolo import yolo_detector
18
+
19
+ logger = logging.get_logger("diffusers")
20
+
21
+
22
+ DetectorType = Callable[[Image.Image], Optional[List[Image.Image]]]
23
+
24
+
25
+ def ordinal(n: int) -> str:
26
+ d = {1: "st", 2: "nd", 3: "rd"}
27
+ return str(n) + ("th" if 11 <= n % 100 <= 13 else d.get(n % 10, "th"))
28
+
29
+
30
+ class AdPipelineBase(ABC):
31
+ @property
32
+ @abstractmethod
33
+ def inpaint_pipeline(self) -> Callable:
34
+ raise NotImplementedError
35
+
36
+ @property
37
+ @abstractmethod
38
+ def txt2img_class(self) -> type:
39
+ raise NotImplementedError
40
+
41
+ def __call__( # noqa: C901
42
+ self,
43
+ common: Mapping[str, Any] | None = None,
44
+ txt2img_only: Mapping[str, Any] | None = None,
45
+ inpaint_only: Mapping[str, Any] | None = None,
46
+ images: Image.Image | Iterable[Image.Image] | None = None,
47
+ detectors: DetectorType | Iterable[DetectorType] | None = None,
48
+ mask_dilation: int = 4,
49
+ mask_blur: int = 4,
50
+ mask_padding: int = 32,
51
+ ):
52
+ if common is None:
53
+ common = {}
54
+ if txt2img_only is None:
55
+ txt2img_only = {}
56
+ if inpaint_only is None:
57
+ inpaint_only = {}
58
+ if "strength" not in inpaint_only:
59
+ inpaint_only = {**inpaint_only, "strength": 0.4}
60
+
61
+ if detectors is None:
62
+ detectors = [self.default_detector]
63
+ elif not isinstance(detectors, Iterable):
64
+ detectors = [detectors]
65
+
66
+ if images is None:
67
+ txt2img_output = self.process_txt2img(common, txt2img_only)
68
+ txt2img_images = txt2img_output[0]
69
+ else:
70
+ if txt2img_only:
71
+ msg = "Both `images` and `txt2img_only` are specified. if `images` is specified, `txt2img_only` is ignored."
72
+ logger.warning(msg)
73
+
74
+ txt2img_images = [images] if not isinstance(images, Iterable) else images
75
+
76
+ init_images = []
77
+ final_images = []
78
+
79
+ for i, init_image in enumerate(txt2img_images):
80
+ init_images.append(init_image.copy())
81
+ final_image = None
82
+
83
+ for j, detector in enumerate(detectors):
84
+ masks = detector(init_image)
85
+ if masks is None:
86
+ logger.info(
87
+ f"No object detected on {ordinal(i + 1)} image with {ordinal(j + 1)} detector."
88
+ )
89
+ continue
90
+
91
+ for k, mask in enumerate(masks):
92
+ mask = mask.convert("L")
93
+ mask = mask_dilate(mask, mask_dilation)
94
+ bbox = mask.getbbox()
95
+ if bbox is None:
96
+ logger.info(f"No object in {ordinal(k + 1)} mask.")
97
+ continue
98
+ mask = mask_gaussian_blur(mask, mask_blur)
99
+ bbox_padded = bbox_padding(bbox, init_image.size, mask_padding)
100
+
101
+ inpaint_output = self.process_inpainting(
102
+ common,
103
+ inpaint_only,
104
+ init_image,
105
+ mask,
106
+ bbox_padded,
107
+ )
108
+ inpaint_image = inpaint_output[0][0]
109
+
110
+ final_image = composite(
111
+ init_image,
112
+ mask,
113
+ inpaint_image,
114
+ bbox_padded,
115
+ )
116
+ init_image = final_image
117
+
118
+ if final_image is not None:
119
+ final_images.append(final_image)
120
+
121
+ return ADOutput(images=final_images, init_images=init_images)
122
+
123
+ @property
124
+ def default_detector(self) -> Callable[..., list[Image.Image] | None]:
125
+ return yolo_detector
126
+
127
+ def _get_txt2img_args(
128
+ self, common: Mapping[str, Any], txt2img_only: Mapping[str, Any]
129
+ ):
130
+ return {**common, **txt2img_only, "output_type": "pil"}
131
+
132
+ def _get_inpaint_args(
133
+ self, common: Mapping[str, Any], inpaint_only: Mapping[str, Any]
134
+ ):
135
+ common = dict(common)
136
+ sig = inspect.signature(self.inpaint_pipeline)
137
+ if (
138
+ "control_image" in sig.parameters
139
+ and "control_image" not in common
140
+ and "image" in common
141
+ ):
142
+ common["control_image"] = common.pop("image")
143
+ return {
144
+ **common,
145
+ **inpaint_only,
146
+ "num_images_per_prompt": 1,
147
+ "output_type": "pil",
148
+ }
149
+
150
+ def process_txt2img(
151
+ self, common: Mapping[str, Any], txt2img_only: Mapping[str, Any]
152
+ ):
153
+ txt2img_args = self._get_txt2img_args(common, txt2img_only)
154
+ return self.txt2img_class.__call__(self, **txt2img_args)
155
+
156
+ def process_inpainting(
157
+ self,
158
+ common: Mapping[str, Any],
159
+ inpaint_only: Mapping[str, Any],
160
+ init_image: Image.Image,
161
+ mask: Image.Image,
162
+ bbox_padded: tuple[int, int, int, int],
163
+ ):
164
+ crop_image = init_image.crop(bbox_padded)
165
+ crop_mask = mask.crop(bbox_padded)
166
+ inpaint_args = self._get_inpaint_args(common, inpaint_only)
167
+ inpaint_args["image"] = crop_image
168
+ inpaint_args["mask_image"] = crop_mask
169
+
170
+ if "control_image" in inpaint_args:
171
+ inpaint_args["control_image"] = inpaint_args["control_image"].resize(
172
+ crop_image.size
173
+ )
174
+ return self.inpaint_pipeline(**inpaint_args)
asdff/sd.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from functools import cached_property
4
+
5
+ from diffusers import (
6
+ StableDiffusionControlNetInpaintPipeline,
7
+ StableDiffusionControlNetPipeline,
8
+ StableDiffusionInpaintPipeline,
9
+ StableDiffusionPipeline,
10
+ )
11
+
12
+ from asdff.base import AdPipelineBase
13
+
14
+
15
+ class AdPipeline(AdPipelineBase, StableDiffusionPipeline):
16
+ @cached_property
17
+ def inpaint_pipeline(self):
18
+ return StableDiffusionInpaintPipeline(
19
+ vae=self.vae,
20
+ text_encoder=self.text_encoder,
21
+ tokenizer=self.tokenizer,
22
+ unet=self.unet,
23
+ scheduler=self.scheduler,
24
+ safety_checker=self.safety_checker,
25
+ feature_extractor=self.feature_extractor,
26
+ requires_safety_checker=self.config.requires_safety_checker,
27
+ )
28
+
29
+ @property
30
+ def txt2img_class(self):
31
+ return StableDiffusionPipeline
32
+
33
+
34
+ class AdCnPipeline(AdPipelineBase, StableDiffusionControlNetPipeline):
35
+ @cached_property
36
+ def inpaint_pipeline(self):
37
+ return StableDiffusionControlNetInpaintPipeline(
38
+ vae=self.vae,
39
+ text_encoder=self.text_encoder,
40
+ tokenizer=self.tokenizer,
41
+ unet=self.unet,
42
+ controlnet=self.controlnet,
43
+ scheduler=self.scheduler,
44
+ safety_checker=self.safety_checker,
45
+ feature_extractor=self.feature_extractor,
46
+ requires_safety_checker=self.config.requires_safety_checker,
47
+ )
48
+
49
+ @property
50
+ def txt2img_class(self):
51
+ return StableDiffusionControlNetPipeline
asdff/utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from diffusers.utils import BaseOutput
8
+ from PIL import Image, ImageFilter, ImageOps
9
+
10
+
11
+ @dataclass
12
+ class ADOutput(BaseOutput):
13
+ images: list[Image.Image]
14
+ init_images: list[Image.Image]
15
+
16
+
17
+ def mask_dilate(image: Image.Image, value: int = 4) -> Image.Image:
18
+ if value <= 0:
19
+ return image
20
+
21
+ arr = np.array(image)
22
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
23
+ dilated = cv2.dilate(arr, kernel, iterations=1)
24
+ return Image.fromarray(dilated)
25
+
26
+
27
+ def mask_gaussian_blur(image: Image.Image, value: int = 4) -> Image.Image:
28
+ if value <= 0:
29
+ return image
30
+
31
+ blur = ImageFilter.GaussianBlur(value)
32
+ return image.filter(blur)
33
+
34
+
35
+ def bbox_padding(
36
+ bbox: tuple[int, int, int, int], image_size: tuple[int, int], value: int = 32
37
+ ) -> tuple[int, int, int, int]:
38
+ if value <= 0:
39
+ return bbox
40
+
41
+ arr = np.array(bbox).reshape(2, 2)
42
+ arr[0] -= value
43
+ arr[1] += value
44
+ arr = np.clip(arr, (0, 0), image_size)
45
+ return tuple(arr.flatten())
46
+
47
+
48
+ def composite(
49
+ init: Image.Image,
50
+ mask: Image.Image,
51
+ gen: Image.Image,
52
+ bbox_padded: tuple[int, int, int, int],
53
+ ) -> Image.Image:
54
+ img_masked = Image.new("RGBa", init.size)
55
+ img_masked.paste(
56
+ init.convert("RGBA").convert("RGBa"),
57
+ mask=ImageOps.invert(mask),
58
+ )
59
+ img_masked = img_masked.convert("RGBA")
60
+
61
+ size = (
62
+ bbox_padded[2] - bbox_padded[0],
63
+ bbox_padded[3] - bbox_padded[1],
64
+ )
65
+ resized = gen.resize(size)
66
+
67
+ output = Image.new("RGBA", init.size)
68
+ output.paste(resized, bbox_padded)
69
+ output.alpha_composite(img_masked)
70
+ return output.convert("RGB")
asdff/yolo.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ from huggingface_hub import hf_hub_download
8
+ from PIL import Image, ImageDraw
9
+ from torchvision.transforms.functional import to_pil_image
10
+
11
+ try:
12
+ from ultralytics import YOLO
13
+ except ModuleNotFoundError:
14
+ print("Please install ultralytics using `pip install ultralytics`")
15
+ raise
16
+
17
+
18
+ def create_mask_from_bbox(
19
+ bboxes: np.ndarray, shape: tuple[int, int]
20
+ ) -> list[Image.Image]:
21
+ """
22
+ Parameters
23
+ ----------
24
+ bboxes: list[list[float]]
25
+ list of [x1, y1, x2, y2]
26
+ bounding boxes
27
+ shape: tuple[int, int]
28
+ shape of the image (width, height)
29
+
30
+ Returns
31
+ -------
32
+ masks: list[Image.Image]
33
+ A list of masks
34
+
35
+ """
36
+ masks = []
37
+ for bbox in bboxes:
38
+ mask = Image.new("L", shape, "black")
39
+ mask_draw = ImageDraw.Draw(mask)
40
+ mask_draw.rectangle(bbox, fill="white")
41
+ masks.append(mask)
42
+ return masks
43
+
44
+
45
+ def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]:
46
+ """
47
+ Parameters
48
+ ----------
49
+ masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
50
+ The device can be CUDA, but `to_pil_image` takes care of that.
51
+
52
+ shape: tuple[int, int]
53
+ (width, height) of the original image
54
+
55
+ Returns
56
+ -------
57
+ images: list[Image.Image]
58
+ """
59
+ n = masks.shape[0]
60
+ return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]
61
+
62
+
63
+ def yolo_detector(
64
+ image: Image.Image, model_path: str | Path | None = None, confidence: float = 0.3
65
+ ) -> list[Image.Image] | None:
66
+ if not model_path:
67
+ model_path = hf_hub_download("Bingsu/adetailer", "face_yolov8n.pt")
68
+ model = YOLO(model_path)
69
+ pred = model(image, conf=confidence)
70
+
71
+ bboxes = pred[0].boxes.xyxy.cpu().numpy()
72
+ if bboxes.size == 0:
73
+ return None
74
+
75
+ if pred[0].masks is None:
76
+ masks = create_mask_from_bbox(bboxes, image.size)
77
+ else:
78
+ masks = mask_to_pil(pred[0].masks.data, image.size)
79
+
80
+ return masks
pipeline.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from asdff import AdPipeline # noqa: F401