Bingsu commited on
Commit
ae02308
1 Parent(s): 13fb3f1

Upload files: v0.1.1

Browse files
Files changed (6) hide show
  1. asdff/__init__.py +9 -0
  2. asdff/__version__.py +1 -0
  3. asdff/sd.py +123 -0
  4. asdff/utils.py +70 -0
  5. asdff/yolo.py +73 -0
  6. pipeline.py +1 -0
asdff/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .__version__ import __version__
2
+ from .sd import AdPipeline
3
+ from .yolo import yolo_detector
4
+
5
+ __all__ = [
6
+ "AdPipeline",
7
+ "yolo_detector",
8
+ "__version__",
9
+ ]
asdff/__version__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.1.1"
asdff/sd.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from functools import cached_property
4
+ from typing import Any, Callable, Iterable, List, Mapping, Optional
5
+
6
+ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
7
+ from diffusers.utils import logging
8
+ from PIL import Image
9
+
10
+ from asdff.utils import (
11
+ ADOutput,
12
+ bbox_padding,
13
+ composite,
14
+ mask_dilate,
15
+ mask_gaussian_blur,
16
+ )
17
+ from asdff.yolo import yolo_detector
18
+
19
+ logger = logging.get_logger("diffusers")
20
+
21
+
22
+ DetectorType = Callable[[Image.Image], Optional[List[Image.Image]]]
23
+
24
+
25
+ def ordinal(n: int) -> str:
26
+ d = {1: "st", 2: "nd", 3: "rd"}
27
+ return str(n) + ("th" if 11 <= n % 100 <= 13 else d.get(n % 10, "th"))
28
+
29
+
30
+ class AdPipeline(StableDiffusionPipeline):
31
+ @cached_property
32
+ def inpaint_pipeline(self):
33
+ return StableDiffusionInpaintPipeline(
34
+ vae=self.vae,
35
+ text_encoder=self.text_encoder,
36
+ tokenizer=self.tokenizer,
37
+ unet=self.unet,
38
+ scheduler=self.scheduler,
39
+ safety_checker=self.safety_checker,
40
+ feature_extractor=self.feature_extractor,
41
+ requires_safety_checker=self.config.requires_safety_checker,
42
+ )
43
+
44
+ def __call__( # noqa: C901
45
+ self,
46
+ common: Mapping[str, Any] | None = None,
47
+ txt2img_only: Mapping[str, Any] | None = None,
48
+ inpaint_only: Mapping[str, Any] | None = None,
49
+ detectors: DetectorType | Iterable[DetectorType] | None = None,
50
+ mask_dilation: int = 4,
51
+ mask_blur: int = 4,
52
+ mask_padding: int = 32,
53
+ ):
54
+ if common is None:
55
+ common = {}
56
+ if txt2img_only is None:
57
+ txt2img_only = {}
58
+ if inpaint_only is None:
59
+ inpaint_only = {}
60
+ if "strength" not in inpaint_only:
61
+ inpaint_only = {**inpaint_only, "strength": 0.4}
62
+
63
+ if detectors is None:
64
+ detectors = [self.default_detector]
65
+ elif callable(detectors):
66
+ detectors = [detectors]
67
+
68
+ txt2img_output = super().__call__(**common, **txt2img_only, output_type="pil")
69
+ txt2img_images: list[Image.Image] = txt2img_output[0]
70
+
71
+ init_images = []
72
+ final_images = []
73
+
74
+ for i, init_image in enumerate(txt2img_images):
75
+ init_images.append(init_image.copy())
76
+ final_image = None
77
+
78
+ for j, detector in enumerate(detectors):
79
+ masks = detector(init_image)
80
+ if masks is None:
81
+ logger.info(
82
+ f"No object detected on {ordinal(i + 1)} image with {ordinal(j + 1)} detector."
83
+ )
84
+ continue
85
+
86
+ for k, mask in enumerate(masks):
87
+ mask = mask.convert("L")
88
+ mask = mask_dilate(mask, mask_dilation)
89
+ bbox = mask.getbbox()
90
+ if bbox is None:
91
+ logger.info(f"No object in {ordinal(k + 1)} mask.")
92
+ continue
93
+ mask = mask_gaussian_blur(mask, mask_blur)
94
+ bbox_padded = bbox_padding(bbox, init_image.size, mask_padding)
95
+
96
+ crop_image = init_image.crop(bbox_padded)
97
+ crop_mask = mask.crop(bbox_padded)
98
+
99
+ inpaint_output = self.inpaint_pipeline(
100
+ **common,
101
+ **inpaint_only,
102
+ image=crop_image,
103
+ mask_image=crop_mask,
104
+ num_images_per_prompt=1,
105
+ output_type="pil",
106
+ )
107
+ inpaint_image: Image.Image = inpaint_output[0][0]
108
+ final_image = composite(
109
+ init=init_image,
110
+ mask=mask,
111
+ gen=inpaint_image,
112
+ bbox_padded=bbox_padded,
113
+ )
114
+ init_image = final_image
115
+
116
+ if final_image is not None:
117
+ final_images.append(final_image)
118
+
119
+ return ADOutput(images=final_images, init_images=init_images)
120
+
121
+ @property
122
+ def default_detector(self) -> Callable[..., list[Image.Image] | None]:
123
+ return yolo_detector
asdff/utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from diffusers.utils import BaseOutput
8
+ from PIL import Image, ImageFilter, ImageOps
9
+
10
+
11
+ @dataclass
12
+ class ADOutput(BaseOutput):
13
+ images: list[Image.Image]
14
+ init_images: list[Image.Image]
15
+
16
+
17
+ def mask_dilate(image: Image.Image, value: int = 4) -> Image.Image:
18
+ if value <= 0:
19
+ return image
20
+
21
+ arr = np.array(image)
22
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
23
+ dilated = cv2.dilate(arr, kernel, iterations=1)
24
+ return Image.fromarray(dilated)
25
+
26
+
27
+ def mask_gaussian_blur(image: Image.Image, value: int = 4) -> Image.Image:
28
+ if value <= 0:
29
+ return image
30
+
31
+ blur = ImageFilter.GaussianBlur(value)
32
+ return image.filter(blur)
33
+
34
+
35
+ def bbox_padding(
36
+ bbox: tuple[int, int, int, int], image_size: tuple[int, int], value: int = 32
37
+ ) -> tuple[int, int, int, int]:
38
+ if value <= 0:
39
+ return bbox
40
+
41
+ arr = np.array(bbox).reshape(2, 2)
42
+ arr[0] -= value
43
+ arr[1] += value
44
+ arr = np.clip(arr, (0, 0), image_size)
45
+ return tuple(arr.flatten())
46
+
47
+
48
+ def composite(
49
+ init: Image.Image,
50
+ mask: Image.Image,
51
+ gen: Image.Image,
52
+ bbox_padded: tuple[int, int, int, int],
53
+ ) -> Image.Image:
54
+ img_masked = Image.new("RGBa", init.size)
55
+ img_masked.paste(
56
+ init.convert("RGBA").convert("RGBa"),
57
+ mask=ImageOps.invert(mask),
58
+ )
59
+ img_masked = img_masked.convert("RGBA")
60
+
61
+ size = (
62
+ bbox_padded[2] - bbox_padded[0],
63
+ bbox_padded[3] - bbox_padded[1],
64
+ )
65
+ resized = gen.resize(size)
66
+
67
+ output = Image.new("RGBA", init.size)
68
+ output.paste(resized, bbox_padded)
69
+ output.alpha_composite(img_masked)
70
+ return output.convert("RGB")
asdff/yolo.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import torch
5
+ from huggingface_hub import hf_hub_download
6
+ from PIL import Image, ImageDraw
7
+ from torchvision.transforms.functional import to_pil_image
8
+ from ultralytics import YOLO
9
+
10
+
11
+ def create_mask_from_bbox(
12
+ bboxes: np.ndarray, shape: tuple[int, int]
13
+ ) -> list[Image.Image]:
14
+ """
15
+ Parameters
16
+ ----------
17
+ bboxes: list[list[float]]
18
+ list of [x1, y1, x2, y2]
19
+ bounding boxes
20
+ shape: tuple[int, int]
21
+ shape of the image (width, height)
22
+
23
+ Returns
24
+ -------
25
+ masks: list[Image.Image]
26
+ A list of masks
27
+
28
+ """
29
+ masks = []
30
+ for bbox in bboxes:
31
+ mask = Image.new("L", shape, "black")
32
+ mask_draw = ImageDraw.Draw(mask)
33
+ mask_draw.rectangle(bbox, fill="white")
34
+ masks.append(mask)
35
+ return masks
36
+
37
+
38
+ def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]:
39
+ """
40
+ Parameters
41
+ ----------
42
+ masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
43
+ The device can be CUDA, but `to_pil_image` takes care of that.
44
+
45
+ shape: tuple[int, int]
46
+ (width, height) of the original image
47
+
48
+ Returns
49
+ -------
50
+ images: list[Image.Image]
51
+ """
52
+ n = masks.shape[0]
53
+ return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]
54
+
55
+
56
+ def yolo_detector(
57
+ image: Image.Image, model_path: str | None = None, confidence: float = 0.3
58
+ ) -> list[Image.Image] | None:
59
+ if not model_path:
60
+ model_path = hf_hub_download("Bingsu/adetailer", "face_yolov8n.pt")
61
+ model = YOLO(model_path)
62
+ pred = model(image, conf=confidence)
63
+
64
+ bboxes = pred[0].boxes.xyxy.cpu().numpy()
65
+ if bboxes.size == 0:
66
+ return None
67
+
68
+ if pred[0].masks is None:
69
+ masks = create_mask_from_bbox(bboxes, image.size)
70
+ else:
71
+ masks = mask_to_pil(pred[0].masks.data, image.size)
72
+
73
+ return masks
pipeline.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from asdff import AdPipeline # noqa: F401