jayparmr committed on
Commit
a3d6c18
·
1 Parent(s): 469e0ba

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +0 -1
  2. carvekit/__init__.py +1 -0
  3. carvekit/__main__.py +149 -0
  4. carvekit/api/__init__.py +0 -0
  5. carvekit/api/high.py +100 -0
  6. carvekit/api/interface.py +77 -0
  7. carvekit/ml/__init__.py +4 -0
  8. carvekit/ml/arch/__init__.py +0 -0
  9. carvekit/ml/arch/basnet/__init__.py +0 -0
  10. carvekit/ml/arch/basnet/basnet.py +478 -0
  11. carvekit/ml/arch/fba_matting/__init__.py +0 -0
  12. carvekit/ml/arch/fba_matting/layers_WS.py +57 -0
  13. carvekit/ml/arch/fba_matting/models.py +341 -0
  14. carvekit/ml/arch/fba_matting/resnet_GN_WS.py +151 -0
  15. carvekit/ml/arch/fba_matting/resnet_bn.py +169 -0
  16. carvekit/ml/arch/fba_matting/transforms.py +45 -0
  17. carvekit/ml/arch/tracerb7/__init__.py +0 -0
  18. carvekit/ml/arch/tracerb7/att_modules.py +290 -0
  19. carvekit/ml/arch/tracerb7/conv_modules.py +88 -0
  20. carvekit/ml/arch/tracerb7/effi_utils.py +579 -0
  21. carvekit/ml/arch/tracerb7/efficientnet.py +325 -0
  22. carvekit/ml/arch/tracerb7/tracer.py +97 -0
  23. carvekit/ml/arch/u2net/__init__.py +0 -0
  24. carvekit/ml/arch/u2net/u2net.py +172 -0
  25. carvekit/ml/files/__init__.py +7 -0
  26. carvekit/ml/files/models_loc.py +70 -0
  27. carvekit/ml/wrap/__init__.py +0 -0
  28. carvekit/ml/wrap/basnet.py +141 -0
  29. carvekit/ml/wrap/deeplab_v3.py +150 -0
  30. carvekit/ml/wrap/fba_matting.py +224 -0
  31. carvekit/ml/wrap/tracer_b7.py +178 -0
  32. carvekit/ml/wrap/u2net.py +140 -0
  33. carvekit/pipelines/__init__.py +0 -0
  34. carvekit/pipelines/postprocessing.py +76 -0
  35. carvekit/pipelines/preprocessing.py +28 -0
  36. carvekit/trimap/__init__.py +0 -0
  37. carvekit/trimap/add_ops.py +91 -0
  38. carvekit/trimap/cv_gen.py +64 -0
  39. carvekit/trimap/generator.py +47 -0
  40. carvekit/utils/__init__.py +0 -0
  41. carvekit/utils/download_models.py +214 -0
  42. carvekit/utils/fs_utils.py +38 -0
  43. carvekit/utils/image_utils.py +150 -0
  44. carvekit/utils/mask_utils.py +85 -0
  45. carvekit/utils/models_utils.py +126 -0
  46. carvekit/utils/pool_utils.py +40 -0
  47. carvekit/web/__init__.py +0 -0
  48. carvekit/web/app.py +30 -0
  49. carvekit/web/deps.py +6 -0
  50. carvekit/web/handlers/__init__.py +0 -0
.gitattributes CHANGED
@@ -25,7 +25,6 @@
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
  *.wasm filter=lfs diff=lfs merge=lfs -text
 
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
28
  *.tflite filter=lfs diff=lfs merge=lfs -text
29
  *.tgz filter=lfs diff=lfs merge=lfs -text
30
  *.wasm filter=lfs diff=lfs merge=lfs -text
carvekit/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ version = "4.1.0"
carvekit/__main__.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import click
4
+ import tqdm
5
+
6
+ from carvekit.utils.image_utils import ALLOWED_SUFFIXES
7
+ from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
8
+ from carvekit.web.schemas.config import MLConfig
9
+ from carvekit.web.utils.init_utils import init_interface
10
+ from carvekit.utils.fs_utils import save_file
11
+
12
+
13
@click.command(
    "removebg",
    help="Performs background removal on specified photos using console interface.",
)
@click.option("-i", required=True, type=str, help="Path to input file or dir")
@click.option("-o", default="none", type=str, help="Path to output file or dir")
@click.option("--pre", default="none", type=str, help="Preprocessing method")
@click.option("--post", default="fba", type=str, help="Postprocessing method.")
@click.option("--net", default="tracer_b7", type=str, help="Segmentation Network")
@click.option(
    "--recursive",
    default=False,
    type=bool,
    help="Enables recursive search for images in a folder",
)
@click.option(
    "--batch_size",
    default=10,
    type=int,
    help="Batch Size for list of images to be loaded to RAM",
)
@click.option(
    "--batch_size_seg",
    default=5,
    type=int,
    help="Batch size for list of images to be processed by segmentation network",
)
@click.option(
    "--batch_size_mat",
    default=1,
    type=int,
    help="Batch size for list of images to be processed by matting network",
)
@click.option(
    "--seg_mask_size",
    default=640,
    type=int,
    help="The size of the input image for the segmentation neural network.",
)
@click.option(
    "--matting_mask_size",
    default=2048,
    type=int,
    help="The size of the input image for the matting neural network.",
)
@click.option(
    "--trimap_dilation",
    default=30,
    type=int,
    help="The size of the offset radius from the object mask in "
    "pixels when forming an unknown area",
)
@click.option(
    "--trimap_erosion",
    default=5,
    type=int,
    help="The number of iterations of erosion that the object's "
    "mask will be subjected to before forming an unknown area",
)
@click.option(
    "--trimap_prob_threshold",
    default=231,
    type=int,
    help="Probability threshold at which the prob_filter "
    "and prob_as_unknown_area operations will be "
    "applied",
)
@click.option("--device", default="cpu", type=str, help="Processing Device.")
@click.option(
    "--fp16", default=False, type=bool, help="Enables mixed precision processing."
)
def removebg(
    i: str,
    o: str,
    pre: str,
    post: str,
    net: str,
    recursive: bool,
    batch_size: int,
    batch_size_seg: int,
    batch_size_mat: int,
    seg_mask_size: int,
    matting_mask_size: int,
    device: str,
    fp16: bool,
    trimap_dilation: int,
    trimap_erosion: int,
    trimap_prob_threshold: int,
):
    """Remove the background from one image or from every image in a folder.

    Args:
        i: Input path. A directory is scanned for files whose lower-cased
            suffix is in ``ALLOWED_SUFFIXES``; anything else is treated as a
            single image path.
        o: Output path handed to ``save_file`` for every processed image.
        pre: Preprocessing method name forwarded to ``MLConfig``.
        post: Postprocessing method name forwarded to ``MLConfig``.
        net: Segmentation network name forwarded to ``MLConfig``.
        recursive: Scan sub-directories of ``i`` as well.
        batch_size: Number of image paths handed to the interface per call.
        batch_size_seg: Batch size forwarded to ``MLConfig`` for segmentation.
        batch_size_mat: Batch size forwarded to ``MLConfig`` for matting.
        seg_mask_size: Segmentation input size forwarded to ``MLConfig``.
        matting_mask_size: Matting input size forwarded to ``MLConfig``.
        device: Processing device forwarded to ``MLConfig``.
        fp16: Mixed-precision flag forwarded to ``MLConfig``.
        trimap_dilation: Trimap dilation forwarded to ``MLConfig``.
        trimap_erosion: Trimap erosion iterations forwarded to ``MLConfig``.
        trimap_prob_threshold: Trimap threshold forwarded to ``MLConfig``.
    """
    out_path = Path(o)
    input_path = Path(i)
    if input_path.is_dir():
        candidates = input_path.rglob("*.*") if recursive else input_path.glob("*.*")
        # Files containing "_bg_removed" are skipped so that results of a
        # previous run are not processed again.
        all_images = [
            path
            for path in candidates
            if path.suffix.lower() in ALLOWED_SUFFIXES
            and "_bg_removed" not in path.name
        ]
    else:
        all_images = [input_path]

    interface_config = MLConfig(
        segmentation_network=net,
        preprocessing_method=pre,
        postprocessing_method=post,
        device=device,
        batch_size_seg=batch_size_seg,
        batch_size_matting=batch_size_mat,
        seg_mask_size=seg_mask_size,
        matting_mask_size=matting_mask_size,
        fp16=fp16,
        trimap_dilation=trimap_dilation,
        trimap_erosion=trimap_erosion,
        trimap_prob_threshold=trimap_prob_threshold,
    )

    interface = init_interface(interface_config)

    # Ceiling division: a trailing partial batch still counts as one batch.
    # The previous floor division reported 0 batches for 5 images at batch
    # size 10. NOTE(review): assumes batch_generator yields the final partial
    # batch - confirm against carvekit.utils.pool_utils.
    total_batches = (len(all_images) + batch_size - 1) // batch_size

    for image_batch in tqdm.tqdm(
        batch_generator(all_images, n=batch_size),
        total=total_batches,
        desc="Removing background",
        unit=" image batch",
        colour="blue",
    ):
        images_without_background = interface(image_batch)  # Remove background
        thread_pool_processing(
            lambda x: save_file(out_path, image_batch[x], images_without_background[x]),
            range(len(image_batch)),
        )  # Drop images to fs
146
+
147
+
148
+ if __name__ == "__main__":
149
+ removebg()
carvekit/api/__init__.py ADDED
File without changes
carvekit/api/high.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ import warnings
7
+
8
+ from carvekit.api.interface import Interface
9
+ from carvekit.ml.wrap.fba_matting import FBAMatting
10
+ from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
11
+ from carvekit.ml.wrap.u2net import U2NET
12
+ from carvekit.pipelines.postprocessing import MattingMethod
13
+ from carvekit.trimap.generator import TrimapGenerator
14
+
15
+
16
class HiInterface(Interface):
    """High-level interface: segmentation network followed by FBA matting."""

    def __init__(
        self,
        object_type: str = "object",
        batch_size_seg=2,
        batch_size_matting=1,
        device="cpu",
        seg_mask_size=640,
        matting_mask_size=2048,
        trimap_prob_threshold=231,
        trimap_dilation=30,
        trimap_erosion_iters=5,
        fp16=False,
    ):
        """
        Initializes High Level interface.

        Args:
            object_type: Interest object type. Can be "object" or "hairs-like".
                Any other value emits a warning and falls back to "object".
            matting_mask_size: The size of the input image for the matting neural network.
            seg_mask_size: The size of the input image for the segmentation neural network.
            batch_size_seg: Number of images processed per one segmentation neural network call.
            batch_size_matting: Number of images processed per one matting neural network call.
            device: Processing device
            fp16: Use half precision. Reduce memory usage and increase speed. Experimental support
            trimap_prob_threshold: Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied
            trimap_dilation: The size of the offset radius from the object mask in pixels when forming an unknown area
            trimap_erosion_iters: The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area

        Notes:
            1. Changing seg_mask_size may cause an out-of-memory error if the value is too large, and it may also
            result in reduced precision. Changing this value is not recommended. You can change matting_mask_size
            in the range 1024 to 4096 to improve object edge refining quality, but this consumes considerably more
            RAM and video memory. You can also increase the batch sizes to accelerate background removal, but
            too large a value consumes considerably more video memory.

            2. Changing trimap_prob_threshold, trimap_dilation or trimap_erosion_iters may improve object edge
            refining quality.
        """
        if object_type == "hairs-like":
            seg_net_cls = U2NET
        else:
            if object_type != "object":
                warnings.warn(
                    f"Unknown object type: {object_type}. Using default object type: object"
                )
            seg_net_cls = TracerUniversalB7

        # The attribute keeps the name `u2net` for backward compatibility even
        # when it holds a TracerUniversalB7 instance.
        self.u2net = seg_net_cls(
            device=device,
            batch_size=batch_size_seg,
            input_image_size=seg_mask_size,
            fp16=fp16,
        )

        self.fba = FBAMatting(
            batch_size=batch_size_matting,
            device=device,
            input_tensor_size=matting_mask_size,
            fp16=fp16,
        )
        self.trimap_generator = TrimapGenerator(
            prob_threshold=trimap_prob_threshold,
            kernel_size=trimap_dilation,
            erosion_iters=trimap_erosion_iters,
        )
        super(HiInterface, self).__init__(
            pre_pipe=None,
            seg_pipe=self.u2net,
            post_pipe=MattingMethod(
                matting_module=self.fba,
                trimap_generator=self.trimap_generator,
                device=device,
            ),
            device=device,
        )
carvekit/api/interface.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ from pathlib import Path
7
+ from typing import Union, List, Optional
8
+
9
+ from PIL import Image
10
+
11
+ from carvekit.ml.wrap.basnet import BASNET
12
+ from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
13
+ from carvekit.ml.wrap.u2net import U2NET
14
+ from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
15
+ from carvekit.pipelines.preprocessing import PreprocessingStub
16
+ from carvekit.pipelines.postprocessing import MattingMethod
17
+ from carvekit.utils.image_utils import load_image
18
+ from carvekit.utils.mask_utils import apply_mask
19
+ from carvekit.utils.pool_utils import thread_pool_processing
20
+
21
+
22
class Interface:
    """Chains the optional pre-processing, segmentation and optional post-processing stages."""

    def __init__(
        self,
        seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7],
        pre_pipe: Optional[Union[PreprocessingStub]] = None,
        post_pipe: Optional[Union[MattingMethod]] = None,
        device="cpu",
    ):
        """
        Stores the pipeline components used by ``__call__``.

        Args:
            seg_pipe: Initialized segmentation network object
            pre_pipe: Initialized pre-processing pipeline object, or None to call the
                segmentation network directly
            post_pipe: Initialized postprocessing pipeline object, or None to apply the
                masks directly
            device: The processing device that will be used to apply the masks to the images.
        """
        self.preprocessing_pipeline = pre_pipe
        self.segmentation_pipeline = seg_pipe
        self.postprocessing_pipeline = post_pipe
        self.device = device

    def __call__(
        self, images: List[Union[str, Path, Image.Image]]
    ) -> List[Image.Image]:
        """
        Removes the background from the specified images.

        Args:
            images: list of input images (paths or PIL images; each is passed to ``load_image``)

        Returns:
            List of images without background as PIL.Image.Image instances
        """
        loaded = thread_pool_processing(load_image, images)

        # Masks come from the pre-processing pipeline when one is configured
        # (it receives this interface), otherwise from the segmentation network.
        if self.preprocessing_pipeline is None:
            masks = self.segmentation_pipeline(images=loaded)
        else:
            masks = self.preprocessing_pipeline(interface=self, images=loaded)

        if self.postprocessing_pipeline is not None:
            return self.postprocessing_pipeline(images=loaded, masks=masks)

        # No post-processing configured: apply each mask to its image directly.
        return [
            apply_mask(image=picture, mask=masks[idx], device=self.device)
            for idx, picture in enumerate(loaded)
        ]
carvekit/ml/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from carvekit.utils.models_utils import fix_seed, suppress_warnings
2
+
3
+ fix_seed()
4
+ suppress_warnings()
carvekit/ml/arch/__init__.py ADDED
File without changes
carvekit/ml/arch/basnet/__init__.py ADDED
File without changes
carvekit/ml/arch/basnet/basnet.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/NathanUA/BASNet
3
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: MIT License
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ from torchvision import models
9
+
10
+
11
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with a padding of 1."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
16
+
17
+
18
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions with an identity (or ``downsample``) shortcut."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: Number of input channels.
            planes: Number of output channels of both convolutions.
            stride: Stride of the first convolution.
            downsample: Optional module applied to the input to build the shortcut.
        """
        super().__init__()
        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Shortcut branch configuration.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
48
+
49
+
50
class BasicBlockDe(nn.Module):
    """Residual block whose shortcut is itself a conv -> BN -> ReLU branch."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: Number of input channels.
            planes: Number of output channels.
            stride: Stride of the shortcut convolution and of the first main convolution.
            downsample: Optional module; when given, its output replaces the learned shortcut.
        """
        super().__init__()

        # Learned shortcut branch.
        self.convRes = conv3x3(inplanes, planes, stride)
        self.bnRes = nn.BatchNorm2d(planes)
        self.reluRes = nn.ReLU(inplace=True)

        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        # The learned shortcut is always evaluated, even if it is replaced below.
        shortcut = self.reluRes(self.bnRes(self.convRes(x)))

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)
87
+
88
+
89
class Bottleneck(nn.Module):
    """Residual block: 1x1 conv, 3x3 conv, then a 1x1 conv that widens to ``planes * 4`` channels."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: Number of input channels.
            planes: Channel width of the two inner convolutions; the block outputs ``planes * 4``.
            stride: Stride of the 3x3 convolution.
            downsample: Optional module applied to the input to build the shortcut.
        """
        super().__init__()
        widened = planes * 4

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
127
+
128
+
129
class RefUnet(nn.Module):
    """Small U-shaped refinement network that adds a predicted residual to its input."""

    def __init__(self, in_ch, inc_ch):
        """
        Args:
            in_ch: Number of channels of the input map.
            inc_ch: Number of channels produced by the first convolution.
        """
        super().__init__()

        def conv(cin, cout):
            # Every convolution in this module is 3x3 with a padding of 1.
            return nn.Conv2d(cin, cout, 3, padding=1)

        def halve():
            return nn.MaxPool2d(2, 2, ceil_mode=True)

        self.conv0 = conv(in_ch, inc_ch)

        # Encoder: four conv-BN-ReLU stages, each followed by a 2x max-pool.
        self.conv1 = conv(inc_ch, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)

        self.pool1 = halve()

        self.conv2 = conv(64, 64)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        self.pool2 = halve()

        self.conv3 = conv(64, 64)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU(inplace=True)

        self.pool3 = halve()

        self.conv4 = conv(64, 64)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU(inplace=True)

        self.pool4 = halve()

        # Bottom of the "U".
        self.conv5 = conv(64, 64)
        self.bn5 = nn.BatchNorm2d(64)
        self.relu5 = nn.ReLU(inplace=True)

        # Decoder: each stage consumes 64 upsampled + 64 skip channels.
        self.conv_d4 = conv(128, 64)
        self.bn_d4 = nn.BatchNorm2d(64)
        self.relu_d4 = nn.ReLU(inplace=True)

        self.conv_d3 = conv(128, 64)
        self.bn_d3 = nn.BatchNorm2d(64)
        self.relu_d3 = nn.ReLU(inplace=True)

        self.conv_d2 = conv(128, 64)
        self.bn_d2 = nn.BatchNorm2d(64)
        self.relu_d2 = nn.ReLU(inplace=True)

        self.conv_d1 = conv(128, 64)
        self.bn_d1 = nn.BatchNorm2d(64)
        self.relu_d1 = nn.ReLU(inplace=True)

        # Single-channel residual head.
        self.conv_d0 = conv(64, 1)

        self.upscore2 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

    def forward(self, x):
        """Return ``x`` plus the single-channel residual predicted from it."""
        feat = self.conv0(x)

        # Encoder path; the pre-pool activations are kept as skip connections.
        enc1 = self.relu1(self.bn1(self.conv1(feat)))
        enc2 = self.relu2(self.bn2(self.conv2(self.pool1(enc1))))
        enc3 = self.relu3(self.bn3(self.conv3(self.pool2(enc2))))
        enc4 = self.relu4(self.bn4(self.conv4(self.pool3(enc3))))
        bottom = self.relu5(self.bn5(self.conv5(self.pool4(enc4))))

        # Decoder path: upsample, concatenate the matching skip, convolve.
        up = self.upscore2(bottom)
        dec4 = self.relu_d4(self.bn_d4(self.conv_d4(torch.cat((up, enc4), 1))))

        up = self.upscore2(dec4)
        dec3 = self.relu_d3(self.bn_d3(self.conv_d3(torch.cat((up, enc3), 1))))

        up = self.upscore2(dec3)
        dec2 = self.relu_d2(self.bn_d2(self.conv_d2(torch.cat((up, enc2), 1))))

        up = self.upscore2(dec2)
        dec1 = self.relu_d1(self.bn_d1(self.conv_d1(torch.cat((up, enc1), 1))))

        return x + self.conv_d0(dec1)
219
+
220
+
221
class BASNet(nn.Module):
    """Encoder / bridge / decoder segmentation network followed by a ``RefUnet`` refinement.

    The encoder reuses ``layer1``..``layer4`` of a torchvision ResNet-34 built
    without pretrained weights. ``forward`` returns eight sigmoid maps: the
    refined output first, then the side outputs from finest to coarsest.
    """

    def __init__(self, n_channels, n_classes):
        """
        Args:
            n_channels: Number of channels of the input image.
            n_classes: Accepted for interface compatibility; not used by this constructor.
        """
        super().__init__()

        backbone = models.resnet34(pretrained=False)
        stage = self._conv_bn_relu

        # ------------- Encoder -------------
        self.inconv = nn.Conv2d(n_channels, 64, 3, padding=1)
        self.inbn = nn.BatchNorm2d(64)
        self.inrelu = nn.ReLU(inplace=True)

        # stages 1-4 come from the ResNet-34 backbone
        self.encoder1 = backbone.layer1
        self.encoder2 = backbone.layer2
        self.encoder3 = backbone.layer3
        self.encoder4 = backbone.layer4

        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)

        # stage 5
        self.resb5_1 = BasicBlock(512, 512)
        self.resb5_2 = BasicBlock(512, 512)
        self.resb5_3 = BasicBlock(512, 512)

        self.pool5 = nn.MaxPool2d(2, 2, ceil_mode=True)

        # stage 6
        self.resb6_1 = BasicBlock(512, 512)
        self.resb6_2 = BasicBlock(512, 512)
        self.resb6_3 = BasicBlock(512, 512)

        # ------------- Bridge (dilated convolutions) -------------
        self.convbg_1, self.bnbg_1, self.relubg_1 = stage(512, 512, dilation=2)
        self.convbg_m, self.bnbg_m, self.relubg_m = stage(512, 512, dilation=2)
        self.convbg_2, self.bnbg_2, self.relubg_2 = stage(512, 512, dilation=2)

        # ------------- Decoder -------------
        # stage 6d
        self.conv6d_1, self.bn6d_1, self.relu6d_1 = stage(1024, 512)
        self.conv6d_m, self.bn6d_m, self.relu6d_m = stage(512, 512, dilation=2)
        self.conv6d_2, self.bn6d_2, self.relu6d_2 = stage(512, 512, dilation=2)

        # stage 5d
        self.conv5d_1, self.bn5d_1, self.relu5d_1 = stage(1024, 512)
        self.conv5d_m, self.bn5d_m, self.relu5d_m = stage(512, 512)
        self.conv5d_2, self.bn5d_2, self.relu5d_2 = stage(512, 512)

        # stage 4d
        self.conv4d_1, self.bn4d_1, self.relu4d_1 = stage(1024, 512)
        self.conv4d_m, self.bn4d_m, self.relu4d_m = stage(512, 512)
        self.conv4d_2, self.bn4d_2, self.relu4d_2 = stage(512, 256)

        # stage 3d
        self.conv3d_1, self.bn3d_1, self.relu3d_1 = stage(512, 256)
        self.conv3d_m, self.bn3d_m, self.relu3d_m = stage(256, 256)
        self.conv3d_2, self.bn3d_2, self.relu3d_2 = stage(256, 128)

        # stage 2d
        self.conv2d_1, self.bn2d_1, self.relu2d_1 = stage(256, 128)
        self.conv2d_m, self.bn2d_m, self.relu2d_m = stage(128, 128)
        self.conv2d_2, self.bn2d_2, self.relu2d_2 = stage(128, 64)

        # stage 1d
        self.conv1d_1, self.bn1d_1, self.relu1d_1 = stage(128, 64)
        self.conv1d_m, self.bn1d_m, self.relu1d_m = stage(64, 64)
        self.conv1d_2, self.bn1d_2, self.relu1d_2 = stage(64, 64)

        # ------------- Bilinear Upsampling -------------
        self.upscore6 = nn.Upsample(
            scale_factor=32, mode="bilinear", align_corners=False
        )
        self.upscore5 = nn.Upsample(
            scale_factor=16, mode="bilinear", align_corners=False
        )
        self.upscore4 = nn.Upsample(
            scale_factor=8, mode="bilinear", align_corners=False
        )
        self.upscore3 = nn.Upsample(
            scale_factor=4, mode="bilinear", align_corners=False
        )
        self.upscore2 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

        # ------------- Side Output -------------
        self.outconvb = nn.Conv2d(512, 1, 3, padding=1)
        self.outconv6 = nn.Conv2d(512, 1, 3, padding=1)
        self.outconv5 = nn.Conv2d(512, 1, 3, padding=1)
        self.outconv4 = nn.Conv2d(256, 1, 3, padding=1)
        self.outconv3 = nn.Conv2d(128, 1, 3, padding=1)
        self.outconv2 = nn.Conv2d(64, 1, 3, padding=1)
        self.outconv1 = nn.Conv2d(64, 1, 3, padding=1)

        # ------------- Refine Module -------------
        self.refunet = RefUnet(1, 64)

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, dilation=1):
        """Create one 3x3 conv, BatchNorm and in-place ReLU triple (padding equals dilation)."""
        return (
            nn.Conv2d(in_ch, out_ch, 3, dilation=dilation, padding=dilation),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return ``(refined, d1, d2, d3, d4, d5, d6, d_bridge)``, each passed through a sigmoid."""
        # ------------- Encoder -------------
        stem = self.inrelu(self.inbn(self.inconv(x)))

        e1 = self.encoder1(stem)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        e5 = self.resb5_3(self.resb5_2(self.resb5_1(self.pool4(e4))))
        e6 = self.resb6_3(self.resb6_2(self.resb6_1(self.pool5(e5))))

        # ------------- Bridge -------------
        t = self.relubg_1(self.bnbg_1(self.convbg_1(e6)))
        t = self.relubg_m(self.bnbg_m(self.convbg_m(t)))
        bridge = self.relubg_2(self.bnbg_2(self.convbg_2(t)))

        # ------------- Decoder -------------
        # Each stage concatenates the previous decoder output with the
        # encoder feature of the same resolution.
        t = self.relu6d_1(self.bn6d_1(self.conv6d_1(torch.cat((bridge, e6), 1))))
        t = self.relu6d_m(self.bn6d_m(self.conv6d_m(t)))
        dec6 = self.relu6d_2(self.bn6d_2(self.conv6d_2(t)))

        t = self.upscore2(dec6)
        t = self.relu5d_1(self.bn5d_1(self.conv5d_1(torch.cat((t, e5), 1))))
        t = self.relu5d_m(self.bn5d_m(self.conv5d_m(t)))
        dec5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(t)))

        t = self.upscore2(dec5)
        t = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((t, e4), 1))))
        t = self.relu4d_m(self.bn4d_m(self.conv4d_m(t)))
        dec4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(t)))

        t = self.upscore2(dec4)
        t = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((t, e3), 1))))
        t = self.relu3d_m(self.bn3d_m(self.conv3d_m(t)))
        dec3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(t)))

        t = self.upscore2(dec3)
        t = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((t, e2), 1))))
        t = self.relu2d_m(self.bn2d_m(self.conv2d_m(t)))
        dec2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(t)))

        t = self.upscore2(dec2)
        t = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((t, e1), 1))))
        t = self.relu1d_m(self.bn1d_m(self.conv1d_m(t)))
        dec1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(t)))

        # ------------- Side Output -------------
        # Each side map is upsampled by a fixed factor (32, 32, 16, 8, 4, 2).
        side_b = self.upscore6(self.outconvb(bridge))
        side6 = self.upscore6(self.outconv6(dec6))
        side5 = self.upscore5(self.outconv5(dec5))
        side4 = self.upscore4(self.outconv4(dec4))
        side3 = self.upscore3(self.outconv3(dec3))
        side2 = self.upscore2(self.outconv2(dec2))
        side1 = self.outconv1(dec1)

        # ------------- Refine Module -------------
        refined = self.refunet(side1)

        logits = (refined, side1, side2, side3, side4, side5, side6, side_b)
        return tuple(torch.sigmoid(m) for m in logits)
carvekit/ml/arch/fba_matting/__init__.py ADDED
File without changes
carvekit/ml/arch/fba_matting/layers_WS.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
3
+ Source url: https://github.com/MarcoForte/FBA_Matting
4
+ License: MIT License
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn import functional as F
9
+
10
+
11
class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` that standardizes its kernels on every forward pass.

    Each output filter is shifted to zero mean and divided by its standard
    deviation before the convolution is applied.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        """Forward all arguments unchanged to ``nn.Conv2d``."""
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x):
        """Convolve ``x`` with the standardized version of ``self.weight``."""
        # Per-filter mean, reduced one axis at a time (input channels, height, width).
        kernel_mean = self.weight
        for axis in (1, 2, 3):
            kernel_mean = kernel_mean.mean(dim=axis, keepdim=True)
        centered = self.weight - kernel_mean

        # Per-filter standard deviation; the two epsilons avoid sqrt(0) and division by zero.
        flat = centered.view(centered.size(0), -1)
        scale = torch.sqrt(torch.var(flat, dim=1) + 1e-12).view(-1, 1, 1, 1) + 1e-5

        standardized = centered / scale.expand_as(centered)
        return F.conv2d(
            x,
            standardized,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
54
+
55
+
56
def BatchNorm2d(num_features):
    """Return a 32-group ``nn.GroupNorm`` over ``num_features`` channels.

    Despite the name, no batch normalization is created; ``num_features``
    must be divisible by 32 for ``nn.GroupNorm`` to accept it.
    """
    group_count = 32
    return nn.GroupNorm(group_count, num_features)
carvekit/ml/arch/fba_matting/models.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
3
+ Source url: https://github.com/MarcoForte/FBA_Matting
4
+ License: MIT License
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import carvekit.ml.arch.fba_matting.resnet_GN_WS as resnet_GN_WS
9
+ import carvekit.ml.arch.fba_matting.layers_WS as L
10
+ import carvekit.ml.arch.fba_matting.resnet_bn as resnet_bn
11
+ from functools import partial
12
+
13
+
14
class FBA(nn.Module):
    """Encoder/decoder pair for FBA matting.

    Args:
        encoder: Architecture name passed to ``build_encoder``. When the name
            contains ``"BN"`` the decoder is built with batch normalization.
    """

    def __init__(self, encoder: str):
        super().__init__()
        self.encoder = build_encoder(arch=encoder)
        self.decoder = fba_decoder(batch_norm="BN" in encoder)

    def forward(self, image, two_chan_trimap, image_n, trimap_transformed):
        """Run the encoder on the stacked inputs and decode the result.

        The encoder input is the channel-wise concatenation of ``image_n``,
        ``trimap_transformed`` and ``two_chan_trimap`` (in that order); the
        decoder additionally receives ``image`` and ``two_chan_trimap``.
        """
        encoder_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1)
        features, pool_indices = self.encoder(encoder_input, return_feature_maps=True)
        return self.decoder(features, image, pool_indices, two_chan_trimap)
24
+
25
+
26
class ResnetDilatedBN(nn.Module):
    """Dilated feature extractor built from a three-conv-stem ResNet.

    The wrapped network's ``layer3``/``layer4`` convolutions are modified in
    place: stride-2 convolutions become stride-1 and 3x3 kernels are dilated.
    Every submodule except the pooling/classifier head is then reused.

    Args:
        orig_resnet: Network exposing ``conv1..3``, ``bn1..3``, ``relu1..3``,
            ``maxpool`` and ``layer1..4``. It is mutated by this constructor.
        dilate_scale: ``8`` dilates layer3 (rate 2) and layer4 (rate 4);
            ``16`` dilates only layer4 (rate 2); other values change nothing.
    """

    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()

        # (layer name, dilation rate) pairs to rewrite for this scale.
        if dilate_scale == 8:
            dilation_plan = (("layer3", 2), ("layer4", 4))
        elif dilate_scale == 16:
            dilation_plan = (("layer4", 2),)
        else:
            dilation_plan = ()
        for layer_name, rate in dilation_plan:
            stage = getattr(orig_resnet, layer_name)
            stage.apply(partial(self._nostride_dilate, dilate=rate))

        # Take the pretrained resnet, except AvgPool and FC.
        reused = (
            "conv1", "bn1", "relu1",
            "conv2", "bn2", "relu2",
            "conv3", "bn3", "relu3",
            "maxpool",
            "layer1", "layer2", "layer3", "layer4",
        )
        for attr in reused:
            setattr(self, attr, getattr(orig_resnet, attr))

    def _nostride_dilate(self, m, dilate):
        """Remove stride and add dilation on ``m`` if it is a convolution."""
        if "Conv" not in m.__class__.__name__:
            return
        if m.stride == (2, 2):
            # The formerly strided convolution gets half the dilation rate.
            m.stride = (1, 1)
            rate = dilate // 2
        else:
            # Every other convolution gets the full rate.
            rate = dilate
        if m.kernel_size == (3, 3):
            m.dilation = (rate, rate)
            m.padding = (rate, rate)

    def forward(self, x, return_feature_maps=False):
        """Return ``[last_feature]`` or ``(feature_list, pool_indices)``.

        The feature list starts with the raw input, followed by the stem
        output and the outputs of ``layer1`` through ``layer4``.
        """
        feature_maps = [x]
        stem = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, bn, relu in stem:
            x = relu(bn(conv(x)))
        feature_maps.append(x)

        x, indices = self.maxpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            feature_maps.append(x)

        if return_feature_maps:
            return feature_maps, indices
        return [x]
86
+
87
+
88
class Resnet(nn.Module):
    """Plain (non-dilated) feature extractor over a three-conv-stem ResNet.

    Args:
        orig_resnet: Network exposing ``conv1..3``, ``bn1..3``, ``relu1..3``,
            ``maxpool`` and ``layer1..4``; its submodules are reused as-is.
    """

    def __init__(self, orig_resnet):
        super().__init__()

        # Take the pretrained resnet, except AvgPool and FC.
        reused = (
            "conv1", "bn1", "relu1",
            "conv2", "bn2", "relu2",
            "conv3", "bn3", "relu3",
            "maxpool",
            "layer1", "layer2", "layer3", "layer4",
        )
        for attr in reused:
            setattr(self, attr, getattr(orig_resnet, attr))

    def forward(self, x, return_feature_maps=False):
        """Return ``[last_feature]`` or the list of intermediate features.

        Unlike the dilated wrappers in this file, the feature list does not
        include the raw input and the pooling indices are not returned. The
        max-pool is still expected to yield an ``(output, indices)`` pair.
        """
        feature_maps = []

        stem = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, bn, relu in stem:
            x = relu(bn(conv(x)))
        feature_maps.append(x)

        x, _ = self.maxpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            feature_maps.append(x)

        return feature_maps if return_feature_maps else [x]
129
+
130
+
131
class ResnetDilated(nn.Module):
    """Dilated feature extractor built from a single-conv-stem ResNet.

    The wrapped network's ``layer3``/``layer4`` convolutions are modified in
    place: stride-2 convolutions become stride-1 and 3x3 kernels are dilated.

    Args:
        orig_resnet: Network exposing ``conv1``, ``bn1``, ``relu``,
            ``maxpool`` and ``layer1..4``. It is mutated by this constructor.
        dilate_scale: ``8`` dilates layer3 (rate 2) and layer4 (rate 4);
            ``16`` dilates only layer4 (rate 2); other values change nothing.
    """

    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()

        # (layer name, dilation rate) pairs to rewrite for this scale.
        if dilate_scale == 8:
            dilation_plan = (("layer3", 2), ("layer4", 4))
        elif dilate_scale == 16:
            dilation_plan = (("layer4", 2),)
        else:
            dilation_plan = ()
        for layer_name, rate in dilation_plan:
            stage = getattr(orig_resnet, layer_name)
            stage.apply(partial(self._nostride_dilate, dilate=rate))

        # Take the pretrained resnet, except AvgPool and FC.
        reused = (
            "conv1", "bn1", "relu", "maxpool",
            "layer1", "layer2", "layer3", "layer4",
        )
        for attr in reused:
            setattr(self, attr, getattr(orig_resnet, attr))

    def _nostride_dilate(self, m, dilate):
        """Remove stride and add dilation on ``m`` if it is a convolution."""
        if "Conv" not in m.__class__.__name__:
            return
        if m.stride == (2, 2):
            # The formerly strided convolution gets half the dilation rate.
            m.stride = (1, 1)
            rate = dilate // 2
        else:
            # Every other convolution gets the full rate.
            rate = dilate
        if m.kernel_size == (3, 3):
            m.dilation = (rate, rate)
            m.padding = (rate, rate)

    def forward(self, x, return_feature_maps=False):
        """Return ``[last_feature]`` or ``(feature_list, pool_indices)``.

        The feature list starts with the raw input, followed by the stem
        output and the outputs of ``layer1`` through ``layer4``.
        """
        feature_maps = [x]
        x = self.relu(self.bn1(self.conv1(x)))
        feature_maps.append(x)

        x, indices = self.maxpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            feature_maps.append(x)

        if return_feature_maps:
            return feature_maps, indices
        return [x]
183
+
184
+
185
def norm(dim, bn=False):
    """Build the normalization layer for ``dim`` channels.

    Returns a 32-group ``nn.GroupNorm`` only when ``bn`` is literally
    ``False``; any other value (including falsy ones such as ``0`` or
    ``None``) selects ``nn.BatchNorm2d``.
    """
    if bn is not False:
        return nn.BatchNorm2d(dim)
    return nn.GroupNorm(32, dim)
190
+
191
+
192
def fba_fusion(alpha, img, F, B):
    """Refine alpha, foreground and background so they agree with ``img``.

    The foreground is updated first, and the background update uses that
    updated foreground. Both are clamped to [0, 1] before alpha is
    re-estimated with a regularization weight of 0.1 and clamped as well.

    Returns:
        Tuple ``(alpha, F, B)`` of the refined tensors.
    """
    cross = alpha * (1 - alpha)

    F = alpha * img + (1 - alpha**2) * F - cross * B
    # Uses the F computed on the previous line, not the input F.
    B = (1 - alpha) * img + (2 * alpha - alpha**2) * B - cross * F

    F = torch.clamp(F, 0, 1)
    B = torch.clamp(B, 0, 1)

    la = 0.1
    fg_minus_bg = F - B
    numerator = alpha * la + torch.sum((img - B) * fg_minus_bg, 1, keepdim=True)
    denominator = torch.sum(fg_minus_bg * fg_minus_bg, 1, keepdim=True) + la
    alpha = torch.clamp(numerator / denominator, 0, 1)
    return alpha, F, B
204
+
205
+
206
class fba_decoder(nn.Module):
    """Decoder that turns encoder features into alpha, foreground, background.

    Args:
        batch_norm: Passed to ``norm`` to choose the normalization layers and
            used to pick the channel count of the skip connection feeding
            ``conv_up3`` (128 when truthy, otherwise 64).
    """

    def __init__(self, batch_norm=False):
        super().__init__()
        pool_scales = (1, 2, 3, 6)
        self.batch_norm = batch_norm

        # Pyramid pooling: one pooled branch per scale over the 2048-channel
        # deepest feature map.
        self.ppm = nn.ModuleList(
            [
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(scale),
                    L.Conv2d(2048, 256, kernel_size=1, bias=True),
                    norm(256, self.batch_norm),
                    nn.LeakyReLU(),
                )
                for scale in pool_scales
            ]
        )

        pyramid_channels = 2048 + len(pool_scales) * 256
        self.conv_up1 = nn.Sequential(
            L.Conv2d(pyramid_channels, 256, kernel_size=3, padding=1, bias=True),
            norm(256, self.batch_norm),
            nn.LeakyReLU(),
            L.Conv2d(256, 256, kernel_size=3, padding=1),
            norm(256, self.batch_norm),
            nn.LeakyReLU(),
        )

        self.conv_up2 = nn.Sequential(
            L.Conv2d(256 + 256, 256, kernel_size=3, padding=1, bias=True),
            norm(256, self.batch_norm),
            nn.LeakyReLU(),
        )

        d_up3 = 128 if self.batch_norm else 64
        self.conv_up3 = nn.Sequential(
            L.Conv2d(256 + d_up3, 64, kernel_size=3, padding=1, bias=True),
            norm(64, self.batch_norm),
            nn.LeakyReLU(),
        )

        self.unpool = nn.MaxUnpool2d(2, stride=2)

        # Final head: 64 decoder channels + 3 (first encoder-input channels)
        # + 3 (image) + 2 (trimap) in, 7 channels out.
        self.conv_up4 = nn.Sequential(
            nn.Conv2d(64 + 3 + 3 + 2, 32, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(),
            nn.Conv2d(32, 16, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(),
            nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True),
        )

    def forward(self, conv_out, img, indices, two_chan_trimap):
        """Decode ``conv_out`` into a 7-channel ``(alpha, F, B)`` tensor.

        ``conv_out`` is indexed at ``[-1]``, ``[-4]``, ``[-5]`` and ``[-6]``,
        so it must hold at least six feature maps. ``indices`` is accepted
        for interface compatibility but not used here.
        """

        def upsample2x(t):
            return torch.nn.functional.interpolate(
                t, scale_factor=2, mode="bilinear", align_corners=False
            )

        deepest = conv_out[-1]
        deepest_size = deepest.size()
        target_hw = (deepest_size[2], deepest_size[3])

        pyramid = [deepest]
        for branch in self.ppm:
            pyramid.append(
                nn.functional.interpolate(
                    branch(deepest),
                    target_hw,
                    mode="bilinear",
                    align_corners=False,
                )
            )
        x = self.conv_up1(torch.cat(pyramid, 1))

        x = upsample2x(x)
        x = self.conv_up2(torch.cat((x, conv_out[-4]), 1))

        x = upsample2x(x)
        x = self.conv_up3(torch.cat((x, conv_out[-5]), 1))

        x = upsample2x(x)
        x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)
        output = self.conv_up4(x)

        # Channel 0 is alpha, 1-3 foreground, 4-6 background.
        alpha = torch.clamp(output[:, 0][:, None], 0, 1)
        F = torch.sigmoid(output[:, 1:4])
        B = torch.sigmoid(output[:, 4:7])

        # FBA Fusion
        alpha, F, B = fba_fusion(alpha, img, F, B)

        return torch.cat((alpha, F, B), 1)
309
+
310
+
311
def build_encoder(arch="resnet50_GN"):
    """Build a dilated ResNet-50 encoder whose first conv takes 11 channels.

    Args:
        arch: ``"resnet50_GN_WS"`` or ``"resnet50_BN"``.
            NOTE(review): the default ``"resnet50_GN"`` is not one of the
            accepted names, so calling this without arguments raises.

    Returns:
        The encoder with ``conv1`` widened to ``3 + 6 + 2`` input channels;
        the first three keep their initial weights, the rest start at zero.

    Raises:
        ValueError: If ``arch`` is not a supported architecture name.
    """
    if arch == "resnet50_GN_WS":
        net_encoder = ResnetDilated(resnet_GN_WS.l_resnet50(), dilate_scale=8)
    elif arch == "resnet50_BN":
        net_encoder = ResnetDilatedBN(resnet_bn.l_resnet50(), dilate_scale=8)
    else:
        raise ValueError("Architecture undefined!")

    # RGB image (3) + transformed trimap (6) + two-channel trimap (2).
    num_channels = 3 + 6 + 2

    state = net_encoder.state_dict()
    stem_weight = state["conv1.weight"]

    out_ch, _, k_h, k_w = stem_weight.size()
    widened = torch.zeros(out_ch, num_channels, k_h, k_w)
    widened[:, :3, :, :] = stem_weight

    # Swap the widened kernel into the live module ...
    conv1 = net_encoder.conv1
    conv1.in_channels = num_channels
    conv1.weight = torch.nn.Parameter(widened)
    net_encoder.conv1 = conv1

    # ... and into the state dict that is loaded back afterwards.
    state["conv1.weight"] = widened
    net_encoder.load_state_dict(state)
    return net_encoder
carvekit/ml/arch/fba_matting/resnet_GN_WS.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
3
+ Source url: https://github.com/MarcoForte/FBA_Matting
4
+ License: MIT License
5
+ """
6
+ import torch.nn as nn
7
+ import carvekit.ml.arch.fba_matting.layers_WS as L
8
+
9
+ __all__ = ["ResNet", "l_resnet50"]
10
+
11
+
12
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 weight-standardized convolution with padding 1."""
    return L.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
17
+
18
+
19
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 weight-standardized convolution."""
    return L.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
22
+
23
+
24
class BasicBlock(nn.Module):
    """Two-convolution residual block using the layers from ``layers_WS``.

    Args:
        inplanes: Input channel count.
        planes: Output channel count (``expansion`` is 1).
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input before it is added
            back as the shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = L.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = L.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply conv-norm-relu-conv-norm, add the shortcut, then ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)
54
+
55
+
56
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block using the layers from ``layers_WS``.

    Args:
        inplanes: Input channel count.
        planes: Width of the two inner convolutions; the block outputs
            ``planes * expansion`` channels.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input before it is added
            back as the shortcut.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * self.expansion
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = L.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = L.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, widened)
        self.bn3 = L.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv stages, add the shortcut, then ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)
92
+
93
+
94
class ResNet(nn.Module):
    """ResNet classifier built from weight-standardized convolutions.

    Args:
        block: Residual block class (must define ``expansion``).
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = L.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # return_indices=True makes this layer return (output, indices).
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, return_indices=True
        )
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the shape changes,
        a 1x1 projection shortcut. ``self.inplanes`` is advanced to the
        stage's output width as a side effect.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                L.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Classify ``x`` and return the output of the final linear layer."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # The pool returns (output, indices); passing the whole tuple on to
        # layer1 would fail, so only the pooled tensor is kept.
        x, _ = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
143
+
144
+
145
def l_resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50 (Bottleneck blocks, stage depths 3-4-6-3).

    Args:
        pretrained (bool): Accepted for interface compatibility only; it is
            not used and no weights are loaded here.
        **kwargs: Forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)
carvekit/ml/arch/fba_matting/resnet_bn.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
3
+ Source url: https://github.com/MarcoForte/FBA_Matting
4
+ License: MIT License
5
+ """
6
+ import torch.nn as nn
7
+ import math
8
+ from torch.nn import BatchNorm2d
9
+
10
+ __all__ = ["ResNet"]
11
+
12
+
13
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
18
+
19
+
20
class BasicBlock(nn.Module):
    """Two-convolution residual block with batch normalization.

    Args:
        inplanes: Input channel count.
        planes: Output channel count (``expansion`` is 1).
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input before it is added
            back as the shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply conv-bn-relu-conv-bn, add the shortcut, then ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)
50
+
51
+
52
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block with batch normalization.

    Args:
        inplanes: Input channel count.
        planes: Width of the two inner convolutions; the block outputs
            ``planes * 4`` channels.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input before it is added
            back as the shortcut.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        out_planes = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        # Only this norm layer overrides the running-stat momentum.
        self.bn2 = BatchNorm2d(planes, momentum=0.01)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv stages, add the shortcut, then ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)
90
+
91
+
92
class ResNet(nn.Module):
    """ResNet classifier with a three-convolution stem and batch norm.

    Args:
        block: Residual block class (must define ``expansion``).
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        # The stem ends with 128 channels, which is what layer1 receives.
        self.inplanes = 128

        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        # return_indices=True makes this layer return (output, indices).
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, return_indices=True
        )

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Conv weights: N(0, sqrt(2 / (kH * kW * out_channels))).
        # Batch norm: weight 1, bias 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                k_h, k_w = module.kernel_size[0], module.kernel_size[1]
                fan = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2.0 / fan))
            elif isinstance(module, BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the shape changes,
        a 1x1 projection shortcut. ``self.inplanes`` is advanced to the
        stage's output width as a side effect.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Classify ``x`` and return the output of the final linear layer."""
        stem = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, bn, relu in stem:
            x = relu(bn(conv(x)))
        x, _ = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
161
+
162
+
163
def l_resnet50():
    """Construct a ResNet-50 (Bottleneck blocks, stage depths 3-4-6-3).

    Takes no arguments and loads no weights; the returned model carries only
    the initialization performed by the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths)
carvekit/ml/arch/fba_matting/transforms.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
3
+ Source url: https://github.com/MarcoForte/FBA_Matting
4
+ License: MIT License
5
+ """
6
+ import cv2
7
+ import numpy as np
8
+
9
+ group_norm_std = [0.229, 0.224, 0.225]
10
+ group_norm_mean = [0.485, 0.456, 0.406]
11
+
12
+
13
def dt(a):
    """Return the L2 distance transform of ``a`` scaled to 8-bit.

    ``a`` is multiplied by 255 and cast to ``uint8`` before being handed to
    ``cv2.distanceTransform`` with mask size ``0``.
    """
    mask_u8 = (a * 255).astype(np.uint8)
    return cv2.distanceTransform(mask_u8, cv2.DIST_L2, 0)
15
+
16
+
17
def trimap_transform(trimap):
    """Encode a two-channel trimap as six Gaussian distance maps.

    For each of the first two channels that has at least one non-zero entry,
    three maps ``exp(-d**2 / (2 * (f * 320)**2))`` are written for
    ``f`` in (0.02, 0.08, 0.16), where ``d`` is ``dt(1 - channel)``.
    Channels that are entirely zero leave their three maps at zero.

    Returns:
        Array of shape ``(h, w, 6)``.
    """
    height, width = trimap.shape[0], trimap.shape[1]
    clicks = np.zeros((height, width, 6))

    L = 320
    sigma_fractions = (0.02, 0.08, 0.16)

    for k in range(2):
        channel = trimap[:, :, k]
        if not np.count_nonzero(channel) > 0:
            continue
        neg_sq_dist = -dt(1 - channel) ** 2
        for offset, fraction in enumerate(sigma_fractions):
            clicks[:, :, 3 * k + offset] = np.exp(
                neg_sq_dist / (2 * ((fraction * L) ** 2))
            )

    return clicks
30
+
31
+
32
def groupnorm_normalise_image(img, format="nhwc"):
    """Normalise the three colour channels of ``img`` in place.

    Accepts RGB in the range 0..1. Each channel has its entry of
    ``group_norm_mean`` subtracted and is divided by its entry of
    ``group_norm_std``.

    Args:
        img: Image data; it is modified in place and also returned.
        format: ``"nhwc"`` treats the last axis as channels; any other value
            treats the third-from-last axis as channels.
    """
    channels_last = format == "nhwc"
    for i, (mean, std) in enumerate(zip(group_norm_mean, group_norm_std)):
        if channels_last:
            img[..., i] = (img[..., i] - mean) / std
        else:
            img[..., i, :, :] = (img[..., i, :, :] - mean) / std
    return img
carvekit/ml/arch/tracerb7/__init__.py ADDED
File without changes
carvekit/ml/arch/tracerb7/att_modules.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/Karel911/TRACER
3
+ Author: Min Seok Lee and Wooseok Shin
4
+ License: Apache License 2.0
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from carvekit.ml.arch.tracerb7.conv_modules import BasicConv2d, DWConv, DWSConv
11
+
12
+
13
class RFB_Block(nn.Module):
    """Four parallel convolution branches fused with a residual connection.

    Branch 0 is a single 1x1 convolution. Branches 1-3 each apply a 1x1
    convolution, a (1, k) and a (k, 1) convolution, and a dilated 3x3
    convolution, with (k, dilation) of (3, 3), (5, 5) and (7, 7).
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        def separable_branch(k, rate):
            half = k // 2
            return nn.Sequential(
                BasicConv2d(in_channel, out_channel, 1),
                BasicConv2d(out_channel, out_channel, kernel_size=(1, k), padding=(0, half)),
                BasicConv2d(out_channel, out_channel, kernel_size=(k, 1), padding=(half, 0)),
                BasicConv2d(out_channel, out_channel, 3, padding=rate, dilation=rate),
            )

        self.relu = nn.ReLU(True)
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = separable_branch(3, 3)
        self.branch2 = separable_branch(5, 5)
        self.branch3 = separable_branch(7, 7)
        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d(in_channel, out_channel, 1)

    def forward(self, x):
        """Concatenate the branch outputs, fuse them, add the 1x1 shortcut."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        stacked = torch.cat([branch(x) for branch in branches], 1)
        fused = self.conv_cat(stacked)
        return self.relu(fused + self.conv_res(x))
51
+
52
+
53
class GlobalAvgPool(nn.Module):
    """Average every channel over all of its trailing positions.

    Args:
        flatten: When true the result has shape ``(N, C)``; otherwise it is
            reshaped to ``(N, C, 1, 1)``.
    """

    def __init__(self, flatten=False):
        super().__init__()
        self.flatten = flatten

    def forward(self, x):
        """Return the per-channel mean of ``x``."""
        batch, channels = x.size(0), x.size(1)
        pooled = x.view(batch, channels, -1).mean(dim=2)
        if self.flatten:
            return pooled
        return pooled.view(batch, channels, 1, 1)
68
+
69
+
70
class UnionAttentionModule(nn.Module):
    """Channel attention followed by spatial attention.

    Args:
        n_channels: Channel count of the input feature map.
        only_channel_tracing: When not ``False`` the spatial query/key/value
            convolutions are not created. NOTE(review): ``forward`` always
            uses them, so it cannot run in that configuration.
    """

    def __init__(self, n_channels, only_channel_tracing=False):
        super().__init__()

        def pointwise(out_channels):
            # Bias-free 1x1 convolution over the module's input channels.
            return nn.Conv2d(
                in_channels=n_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )

        self.GAP = GlobalAvgPool()
        # Quantile used by ``masking`` and dropout probability of ``norm``.
        self.confidence_ratio = 0.1
        self.bn = nn.BatchNorm2d(n_channels)
        self.norm = nn.Sequential(
            nn.BatchNorm2d(n_channels), nn.Dropout3d(self.confidence_ratio)
        )

        self.channel_q = pointwise(n_channels)
        self.channel_k = pointwise(n_channels)
        self.channel_v = pointwise(n_channels)

        self.fc = pointwise(n_channels)

        if only_channel_tracing is False:
            self.spatial_q = pointwise(1)
            self.spatial_k = pointwise(1)
            self.spatial_v = pointwise(1)
        self.sigmoid = nn.Sigmoid()

    def masking(self, x, mask):
        """Zero the lowest-scoring channels of ``mask`` and scale ``x`` by it.

        Entries of ``mask`` at or below its ``confidence_ratio`` quantile
        (taken along the last axis after squeezing dims 3 and 2) are set to
        zero. The assignment is in place, so the caller's ``mask`` tensor can
        be modified.
        """
        scores = mask.squeeze(3).squeeze(2)
        cutoff = torch.quantile(
            scores.float(), self.confidence_ratio, dim=-1, keepdim=True
        )
        scores[scores <= cutoff] = 0.0

        gate = scores.unsqueeze(2).unsqueeze(3)
        gate = gate.expand(-1, x.shape[1], x.shape[2], x.shape[3]).contiguous()
        return x * gate

    def Channel_Tracer(self, x):
        """Return ``(x * att + x, copy of att)`` for the channel attention."""
        pooled = self.norm(self.GAP(x))

        q = self.channel_q(pooled).squeeze(-1)
        k = self.channel_k(pooled).squeeze(-1)
        v = self.channel_v(pooled).squeeze(-1)

        # softmax(Q*K^T)
        weights = F.softmax(torch.matmul(q, k.transpose(1, 2)), dim=-1)

        # a*v
        att = torch.matmul(weights, v).unsqueeze(-1)
        att = self.sigmoid(self.fc(att))

        output = (x * att) + x
        return output, att.clone()

    def forward(self, x):
        """Apply channel attention, channel masking, then spatial attention."""
        traced, channel_mask = self.Channel_Tracer(x)
        traced = self.bn(traced)
        kept = self.masking(traced, channel_mask)

        q = self.spatial_q(kept).squeeze(1)
        k = self.spatial_k(kept).squeeze(1)
        v = self.spatial_v(kept).squeeze(1)

        # softmax(Q*K^T)
        weights = F.softmax(torch.matmul(q, k.transpose(1, 2)), dim=-1)

        return torch.matmul(weights, v).unsqueeze(1) + v.unsqueeze(1)
190
+
191
+
192
class aggregation(nn.Module):
    """Fuse three encoder feature maps and apply union attention.

    Args:
        channel: Sequence of three channel counts; ``channel[0]`` belongs to
            ``e2``, ``channel[1]`` to ``e3`` and ``channel[2]`` to ``e4``.
    """

    def __init__(self, channel):
        super().__init__()
        c_low, c_mid, c_high = channel[0], channel[1], channel[2]

        self.relu = nn.ReLU(True)

        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv_upsample1 = BasicConv2d(c_high, c_mid, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(c_high, c_low, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(c_mid, c_low, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(c_high, c_high, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(
            c_high + c_mid, c_high + c_mid, 3, padding=1
        )

        self.conv_concat2 = BasicConv2d(
            (c_high + c_mid), (c_high + c_mid), 3, padding=1
        )
        self.conv_concat3 = BasicConv2d(
            (c_low + c_mid + c_high),
            (c_low + c_mid + c_high),
            3,
            padding=1,
        )

        self.UAM = UnionAttentionModule(c_low + c_mid + c_high)

    def forward(self, e4, e3, e2):
        """Merge ``e4`` into ``e3`` and both into ``e2``, then attend.

        Each merge upsamples the coarser map by 2, so ``e3`` is expected at
        twice and ``e2`` at four times the resolution of ``e4``.
        """
        e4_up = self.upsample(e4)

        # Multiplicative gating of the finer maps by the upsampled coarser ones.
        e3_gated = self.conv_upsample1(e4_up) * e3
        e2_gated = (
            self.conv_upsample2(self.upsample(e4_up))
            * self.conv_upsample3(self.upsample(e3))
            * e2
        )

        e3_merged = torch.cat((e3_gated, self.conv_upsample4(e4_up)), 1)
        e3_merged = self.conv_concat2(e3_merged)

        e2_merged = torch.cat(
            (e2_gated, self.conv_upsample5(self.upsample(e3_merged))), 1
        )
        fused = self.conv_concat3(e2_merged)

        return self.UAM(fused)
236
+
237
+
238
class ObjectAttention(nn.Module):
    """Refine a decoder map using object and edge attention on encoder features.

    Args:
        channel: Channel count of the encoder feature map.
        kernel_size: Kernel size of the first depthwise-separable convolution
            (its padding is fixed at 1).
    """

    def __init__(self, channel, kernel_size):
        super().__init__()
        half, eighth = channel // 2, channel // 8

        def dilated_branch(kernel, rate, padding):
            # Depthwise conv at the given dilation, then a 1x1 reduction.
            return nn.Sequential(
                DWConv(half, half, kernel=kernel, padding=padding, dilation=rate),
                BasicConv2d(half, eighth, 1),
            )

        self.channel = channel
        self.DWSConv = DWSConv(
            channel, half, kernel=kernel_size, padding=1, kernels_per_layer=1
        )
        self.DWConv1 = dilated_branch(kernel=1, rate=1, padding=0)
        self.DWConv2 = dilated_branch(kernel=3, rate=1, padding=1)
        self.DWConv3 = dilated_branch(kernel=3, rate=3, padding=3)
        self.DWConv4 = dilated_branch(kernel=3, rate=5, padding=5)
        self.conv1 = BasicConv2d(half, 1, 1)

    def forward(self, decoder_map, encoder_map):
        """
        Args:
            decoder_map: decoder representation (B, 1, H, W).
            encoder_map: encoder block output (B, C, H, W).
        Returns:
            decoder representation: (B, 1, H, W)
        """
        mask_ob = torch.sigmoid(decoder_map)  # object attention
        mask_bg = -1 * mask_ob + 1  # Sigmoid & Reverse
        x = mask_ob.expand(-1, self.channel, -1, -1).mul(encoder_map)

        # Background probabilities above 0.93 are dropped from the edge mask.
        edge = mask_bg.masked_fill(mask_bg > 0.93, 0)
        x = x + (edge * encoder_map)

        x = self.DWSConv(x)
        skip = x.clone()
        branches = (self.DWConv1, self.DWConv2, self.DWConv3, self.DWConv4)
        x = torch.cat([branch(x) for branch in branches], dim=1) + skip
        x = torch.relu(self.conv1(x))

        return x + decoder_map
carvekit/ml/arch/tracerb7/conv_modules.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/Karel911/TRACER
3
+ Author: Min Seok Lee and Wooseok Shin
4
+ License: Apache License 2.0
5
+ """
6
+ import torch.nn as nn
7
+
8
+
9
class BasicConv2d(nn.Module):
    """Bias-free 2D convolution followed by batch normalization and SELU."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        stride=(1, 1),
        padding=(0, 0),
        dilation=(1, 1),
    ):
        """
        Args:
            in_channel: number of input channels.
            out_channel: number of output channels.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            padding: convolution padding.
            dilation: convolution dilation.
        """
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.conv = nn.Conv2d(in_channel, out_channel, **conv_options)
        self.bn = nn.BatchNorm2d(out_channel)
        self.selu = nn.SELU()

    def forward(self, x):
        """Apply conv -> batch norm -> SELU to ``x``."""
        return self.selu(self.bn(self.conv(x)))
38
+
39
+
40
class DWConv(nn.Module):
    """Depthwise (groups == in_channel) convolution with batch norm and SELU."""

    def __init__(self, in_channel, out_channel, kernel, dilation, padding):
        """
        Args:
            in_channel: number of input channels; also used as the group count.
            out_channel: number of output channels.
            kernel: convolution kernel size.
            dilation: convolution dilation.
            padding: convolution padding.
        """
        super().__init__()
        self.out_channel = out_channel
        conv_options = dict(
            kernel_size=kernel,
            padding=padding,
            groups=in_channel,
            dilation=dilation,
            bias=False,
        )
        self.DWConv = nn.Conv2d(in_channel, out_channel, **conv_options)
        self.bn = nn.BatchNorm2d(out_channel)
        self.selu = nn.SELU()

    def forward(self, x):
        """Apply depthwise conv -> batch norm -> SELU to ``x``."""
        convolved = self.DWConv(x)
        normalized = self.bn(convolved)
        return self.selu(normalized)
61
+
62
+
63
class DWSConv(nn.Module):
    """Depthwise-separable convolution: depthwise conv then 1x1 pointwise conv.

    Each convolution is followed by batch normalization and SELU.
    """

    def __init__(self, in_channel, out_channel, kernel, padding, kernels_per_layer):
        """
        Args:
            in_channel: number of input channels; also the depthwise group count.
            out_channel: number of channels produced by the pointwise conv.
            kernel: kernel size of the depthwise conv.
            padding: padding of the depthwise conv.
            kernels_per_layer: channel multiplier of the depthwise conv.
        """
        super().__init__()
        self.out_channel = out_channel
        depthwise_channels = in_channel * kernels_per_layer

        self.DWConv = nn.Conv2d(
            in_channel,
            depthwise_channels,
            kernel_size=kernel,
            padding=padding,
            groups=in_channel,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(depthwise_channels)
        self.selu = nn.SELU()
        self.PWConv = nn.Conv2d(
            depthwise_channels, out_channel, kernel_size=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        """Run the depthwise stage, then the pointwise stage, on ``x``."""
        depthwise = self.selu(self.bn(self.DWConv(x)))
        pointwise = self.selu(self.bn2(self.PWConv(depthwise)))
        return pointwise
carvekit/ml/arch/tracerb7/effi_utils.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Original author: lukemelas (github username)
3
+ Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
4
+ With adjustments and added comments by workingcoder (github username).
5
+ License: Apache License 2.0
6
+ Reimplemented: Min Seok Lee and Wooseok Shin
7
+ """
8
+
9
+ import collections
10
+ import re
11
+ from functools import partial
12
+
13
+ import math
14
+ import torch
15
+ from torch import nn
16
+ from torch.nn import functional as F
17
+
18
+ # Parameters for the entire model (stem, all blocks, and head)
19
GlobalParams = collections.namedtuple(
    "GlobalParams",
    "width_coefficient depth_coefficient image_size dropout_rate num_classes "
    "batch_norm_momentum batch_norm_epsilon drop_connect_rate depth_divisor "
    "min_depth include_top",
)

# Parameters for an individual model block
BlockArgs = collections.namedtuple(
    "BlockArgs",
    "num_repeat kernel_size stride expand_ratio input_filters output_filters "
    "se_ratio id_skip",
)

# Make every field of both tuples optional, defaulting to None
GlobalParams.__new__.__defaults__ = tuple(None for _ in GlobalParams._fields)
BlockArgs.__new__.__defaults__ = tuple(None for _ in BlockArgs._fields)
54
+
55
+
56
+ # An ordinary implementation of Swish function
57
class Swish(nn.Module):
    """Plain Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
60
+
61
+
62
+ # A memory-efficient implementation of Swish function
63
class SwishImplementation(torch.autograd.Function):
    """Swish as a custom autograd function.

    Only the input is saved for the backward pass; the sigmoid is recomputed
    there instead of being kept from the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and stash ``i`` for backward."""
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        """Chain ``grad_output`` with d/di [i * sigmoid(i)]."""
        (saved_input,) = ctx.saved_tensors
        gate = torch.sigmoid(saved_input)
        local_grad = gate * (1 + saved_input * (1 - gate))
        return grad_output * local_grad
75
+
76
+
77
class MemoryEfficientSwish(nn.Module):
    """Swish activation routed through the custom ``SwishImplementation`` autograd function."""

    def forward(self, x):
        activated = SwishImplementation.apply(x)
        return activated
80
+
81
+
82
def round_filters(filters, global_params):
    """Scale a filter count by the width multiplier and round it to the divisor.

    Reads width_coefficient, depth_divisor and min_depth from global_params.

    Args:
        filters (int): Filter count to scale.
        global_params (namedtuple): Global params of the model.

    Returns:
        new_filters: The scaled and rounded filter count, or ``filters``
        unchanged when no width multiplier is set.
    """
    width_multiplier = global_params.width_coefficient
    if not width_multiplier:
        return filters

    divisor = global_params.depth_divisor
    # min_depth falls back to the divisor when it is unset (or zero)
    lower_bound = global_params.min_depth or divisor
    scaled = filters * width_multiplier

    # same rounding formula as the official TensorFlow implementation
    rounded = max(lower_bound, int(scaled + divisor / 2) // divisor * divisor)
    if rounded < 0.9 * scaled:  # never round down by more than 10%
        rounded += divisor
    return int(rounded)
105
+
106
+
107
def round_repeats(repeats, global_params):
    """Scale a block's repeat count by the depth multiplier, rounding up.

    Reads depth_coefficient from global_params.

    Args:
        repeats (int): num_repeat to scale.
        global_params (namedtuple): Global params of the model.

    Returns:
        new repeat: The scaled repeat count, or ``repeats`` unchanged when no
        depth multiplier is set.
    """
    depth_multiplier = global_params.depth_coefficient
    # same formula as the official TensorFlow implementation
    return int(math.ceil(depth_multiplier * repeats)) if depth_multiplier else repeats
123
+
124
+
125
def drop_connect(inputs, p, training):
    """Drop connect: randomly zero whole samples of a batch during training.

    Each sample along the first dimension is kept with probability ``1 - p``
    and kept samples are rescaled by ``1 / (1 - p)``.

    Args:
        inputs (tensor): Input of this structure; the first dimension is
            treated as the batch and the mask has shape [batch, 1, 1, 1].
        p (float: 0.0~1.0): Probability of drop connection.
        training (bool): The running mode.

    Returns:
        output: ``inputs`` itself when not training, otherwise the masked and
        rescaled tensor.
    """
    assert 0 <= p <= 1, "p must be in range of [0,1]"

    if not training:
        return inputs

    keep_prob = 1 - p
    if keep_prob == 0:
        # Everything is dropped. Dividing by keep_prob here would compute
        # inputs / 0 * 0, which is NaN rather than zero.
        return torch.zeros_like(inputs)

    batch_size = inputs.shape[0]

    # generate binary_tensor mask according to probability (p for 0, 1-p for 1):
    # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else to 0
    random_tensor = keep_prob + torch.rand(
        [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
    )
    binary_tensor = torch.floor(random_tensor)

    return inputs / keep_prob * binary_tensor
153
+
154
+
155
def get_width_and_height_from_size(x):
    """Normalize a size given as an int or a pair into a pair.

    Args:
        x (int, tuple or list): Data size.

    Returns:
        size: ``(x, x)`` for an int; the list or tuple itself otherwise.

    Raises:
        TypeError: If ``x`` is not an int, list or tuple.
    """
    if isinstance(x, int):
        return (x, x)
    if isinstance(x, (list, tuple)):
        return x
    raise TypeError()
170
+
171
+
172
def calculate_output_image_size(input_image_size, stride):
    """Compute the image size produced by a 'same'-padded op with a stride.

    Needed to pre-compute static padding.

    Args:
        input_image_size (int, tuple or list): Size of input image.
        stride (int, tuple or list): Stride of the operation; for a sequence
            only the first element is used, for both dimensions.

    Returns:
        output_image_size: A list [H,W], or None when no input size is given.
    """
    if input_image_size is None:
        return None

    height, width = get_width_and_height_from_size(input_image_size)
    step = stride if isinstance(stride, int) else stride[0]
    return [int(math.ceil(height / step)), int(math.ceil(width / step))]
190
+
191
+
192
+ # Note:
193
+ # The following 'SamePadding' functions make output size equal ceil(input size/stride).
194
+ # Only when stride equals 1, can the output size be the same as input size.
195
+ # Don't be confused by their function names ! ! !
196
+
197
+
198
def get_same_padding_conv2d(image_size=None):
    """Pick the 'same'-padding convolution class matching the available size info.

    Static padding is used when an image size is given (necessary for ONNX
    export), dynamic padding otherwise.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        Conv2dDynamicSamePadding, or Conv2dStaticSamePadding with
        ``image_size`` already bound.
    """
    if image_size is not None:
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding
212
+
213
+
214
class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style 'SAME' padding for any input size.

    The padding is computed from the actual input in every forward call.
    """

    # How the 'SAME' padding amount is derived, per spatial dimension:
    #     i: input size, s: stride, k: kernel size, d: dilation, p: total padding
    # A convolution produces
    #     o = floor((i + p - ((k - 1) * d + 1)) / s + 1)
    # and solving for the padding that yields a wanted o gives
    #     p = (o - 1) * s + ((k - 1) * d + 1) - i
    # Here o = ceil(i / s), so the output equals the input only for stride 1.

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        # The built-in padding is fixed to 0; padding is applied manually in forward.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

    def forward(self, x):
        """Pad ``x`` so the output is ceil(input / stride) in size, then convolve."""
        in_h, in_w = x.size()[-2:]
        kernel_h, kernel_w = self.weight.size()[-2:]
        stride_h, stride_w = self.stride

        # the output size follows the stride
        out_h = math.ceil(in_h / stride_h)
        out_w = math.ceil(in_w / stride_w)

        pad_h = max(
            (out_h - 1) * self.stride[0] + (kernel_h - 1) * self.dilation[0] + 1 - in_h,
            0,
        )
        pad_w = max(
            (out_w - 1) * self.stride[1] + (kernel_w - 1) * self.dilation[1] + 1 - in_w,
            0,
        )

        if pad_h > 0 or pad_w > 0:
            # An odd padding amount puts the extra row/column at the bottom/right.
            left, top = pad_w // 2, pad_h // 2
            x = F.pad(x, [left, pad_w - left, top, pad_h - top])

        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
268
+
269
+
270
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style 'SAME' padding for a known input size.

    The padding module is built once in the constructor from ``image_size``
    and then applied in forward. The amount is computed exactly as in
    Conv2dDynamicSamePadding.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        image_size=None,
        **kwargs
    ):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

        # Derive the padding from the declared image size and store it
        assert image_size is not None
        if isinstance(image_size, int):
            in_h, in_w = image_size, image_size
        else:
            in_h, in_w = image_size
        kernel_h, kernel_w = self.weight.size()[-2:]
        stride_h, stride_w = self.stride
        out_h = math.ceil(in_h / stride_h)
        out_w = math.ceil(in_w / stride_w)
        pad_h = max(
            (out_h - 1) * self.stride[0] + (kernel_h - 1) * self.dilation[0] + 1 - in_h,
            0,
        )
        pad_w = max(
            (out_w - 1) * self.stride[1] + (kernel_w - 1) * self.dilation[1] + 1 - in_w,
            0,
        )

        if pad_h > 0 or pad_w > 0:
            left, top = pad_w // 2, pad_h // 2
            self.static_padding = nn.ZeroPad2d((left, pad_w - left, top, pad_h - top))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        """Apply the precomputed padding, then the convolution."""
        padded = self.static_padding(x)
        return F.conv2d(
            padded,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
316
+
317
+
318
def get_same_padding_maxPool2d(image_size=None):
    """Pick the 'same'-padding max-pooling class matching the available size info.

    Static padding is used when an image size is given (necessary for ONNX
    export), dynamic padding otherwise.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        MaxPool2dDynamicSamePadding, or MaxPool2dStaticSamePadding with
        ``image_size`` already bound.
    """
    if image_size is not None:
        return partial(MaxPool2dStaticSamePadding, image_size=image_size)
    return MaxPool2dDynamicSamePadding
332
+
333
+
334
class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
    """2D max pooling with TensorFlow-style 'SAME' padding for any input size.

    The padding is computed from the actual input in every forward call.
    """

    def __init__(
        self,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        return_indices=False,
        ceil_mode=False,
    ):
        super().__init__(
            kernel_size, stride, padding, dilation, return_indices, ceil_mode
        )
        # Expand int settings to [h, w] pairs so they can be indexed per dimension.
        if isinstance(self.stride, int):
            self.stride = [self.stride] * 2
        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size] * 2
        if isinstance(self.dilation, int):
            self.dilation = [self.dilation] * 2

    def forward(self, x):
        """Zero-pad ``x`` so the output is ceil(input / stride) in size, then pool."""
        in_h, in_w = x.size()[-2:]
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        out_h = math.ceil(in_h / stride_h)
        out_w = math.ceil(in_w / stride_w)

        pad_h = max(
            (out_h - 1) * self.stride[0] + (kernel_h - 1) * self.dilation[0] + 1 - in_h,
            0,
        )
        pad_w = max(
            (out_w - 1) * self.stride[1] + (kernel_w - 1) * self.dilation[1] + 1 - in_w,
            0,
        )

        if pad_h > 0 or pad_w > 0:
            left, top = pad_w // 2, pad_h // 2
            x = F.pad(x, [left, pad_w - left, top, pad_h - top])

        return F.max_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.ceil_mode,
            self.return_indices,
        )
381
+
382
+
383
class MaxPool2dStaticSamePadding(nn.MaxPool2d):
    """2D max pooling with TensorFlow-style 'SAME' padding for a known input size.

    The padding module is built once in the constructor from ``image_size``
    and then applied in forward.
    """

    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
        super().__init__(kernel_size, stride, **kwargs)
        # Expand int settings to [h, w] pairs so they can be indexed per dimension.
        if isinstance(self.stride, int):
            self.stride = [self.stride] * 2
        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size] * 2
        if isinstance(self.dilation, int):
            self.dilation = [self.dilation] * 2

        # Derive the padding from the declared image size and store it
        assert image_size is not None
        if isinstance(image_size, int):
            in_h, in_w = image_size, image_size
        else:
            in_h, in_w = image_size
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        out_h = math.ceil(in_h / stride_h)
        out_w = math.ceil(in_w / stride_w)
        pad_h = max(
            (out_h - 1) * self.stride[0] + (kernel_h - 1) * self.dilation[0] + 1 - in_h,
            0,
        )
        pad_w = max(
            (out_w - 1) * self.stride[1] + (kernel_w - 1) * self.dilation[1] + 1 - in_w,
            0,
        )

        if pad_h > 0 or pad_w > 0:
            left, top = pad_w // 2, pad_h // 2
            self.static_padding = nn.ZeroPad2d((left, pad_w - left, top, pad_h - top))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        """Apply the precomputed padding, then max pooling."""
        padded = self.static_padding(x)
        return F.max_pool2d(
            padded,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.ceil_mode,
            self.return_indices,
        )
427
+
428
+
429
class BlockDecoder(object):
    """Convert between BlockArgs tuples and their compact string notation.

    A block string is a list of ``_``-separated options, each a letter key
    followed by a value, e.g. 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: The namedtuple defined at the top of this file.
        """
        assert isinstance(block_string, str)

        options = {}
        for op in block_string.split("_"):
            # Split at the first digit: 'se0.25' -> ['se', '0.25', ''].
            # Options without a digit (e.g. 'noskip') yield one piece and are skipped.
            splits = re.split(r"(\d.*)", op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride: one digit, or two identical digits. The presence test
        # guards both alternatives so a missing stride fails the assertion
        # instead of raising KeyError.
        stride_digits = options.get("s")
        assert stride_digits is not None and (
            len(stride_digits) == 1
            or (len(stride_digits) == 2 and stride_digits[0] == stride_digits[1])
        ), "block string needs a stride option of one digit or two equal digits"

        return BlockArgs(
            num_repeat=int(options["r"]),
            kernel_size=int(options["k"]),
            stride=[int(stride_digits[0])],
            expand_ratio=int(options["e"]),
            input_filters=int(options["i"]),
            output_filters=int(options["o"]),
            se_ratio=float(options["se"]) if "se" in options else None,
            id_skip=("noskip" not in block_string),
        )

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Args:
            block (namedtuple): A BlockArgs type argument. ``stride`` may be an
                int or a sequence; a one-element sequence is used for both
                dimensions.

        Returns:
            block_string: A String form of BlockArgs.
        """
        # The field is named `stride` and the decoder stores it as a
        # one-element list, so accept an int, [s] or [s_h, s_w].
        stride = block.stride
        if isinstance(stride, int):
            stride_h = stride_w = stride
        else:
            stride_h, stride_w = stride[0], stride[-1]

        args = [
            "r%d" % block.num_repeat,
            "k%d" % block.kernel_size,
            "s%d%d" % (stride_h, stride_w),
            "e%s" % block.expand_ratio,
            "i%d" % block.input_filters,
            "o%d" % block.output_filters,
        ]
        # se_ratio is None when the block string carried no 'se' option.
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append("se%s" % block.se_ratio)
        if block.id_skip is False:
            args.append("noskip")
        return "_".join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.

        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        return [
            BlockDecoder._decode_block_string(block_string)
            for block_string in string_list
        ]

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.

        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        return [BlockDecoder._encode_block_string(block) for block in blocks_args]
525
+
526
+
527
def create_block_args(
    width_coefficient=None,
    depth_coefficient=None,
    image_size=None,
    dropout_rate=0.2,
    drop_connect_rate=0.2,
    num_classes=1000,
    include_top=True,
):
    """Build the BlockArgs list and GlobalParams for an efficientnet model.

    Args:
        width_coefficient (float)
        depth_coefficient (float)
        image_size (int)
        dropout_rate (float)
        drop_connect_rate (float)
        num_classes (int)
        include_top (bool)

        Each value is stored unchanged in the field of the same name.

    Returns:
        blocks_args, global_params.
    """

    # Block layout of the whole model (efficientnet-b0 by default); the
    # EfficientNet constructor rescales it according to the coefficients.
    block_strings = [
        "r1_k3_s11_e1_i32_o16_se0.25",
        "r2_k3_s22_e6_i16_o24_se0.25",
        "r2_k5_s22_e6_i24_o40_se0.25",
        "r3_k3_s22_e6_i40_o80_se0.25",
        "r3_k5_s11_e6_i80_o112_se0.25",
        "r4_k5_s22_e6_i112_o192_se0.25",
        "r1_k3_s11_e6_i192_o320_se0.25",
    ]
    decoded_blocks = BlockDecoder.decode(block_strings)

    params = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,
        num_classes=num_classes,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=include_top,
    )

    return decoded_blocks, params
carvekit/ml/arch/tracerb7/efficientnet.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/lukemelas/EfficientNet-PyTorch
3
+ Modified by Min Seok Lee, Wooseok Shin, Nikita Selin
4
+ License: Apache License 2.0
5
+ Changes:
6
+ - Added support for extracting edge features
7
+ - Added support for extracting object features at different levels
8
+ - Refactored the code
9
+ """
10
+ from typing import Any, List
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+
16
+ from carvekit.ml.arch.tracerb7.effi_utils import (
17
+ get_same_padding_conv2d,
18
+ calculate_output_image_size,
19
+ MemoryEfficientSwish,
20
+ drop_connect,
21
+ round_filters,
22
+ round_repeats,
23
+ Swish,
24
+ create_block_args,
25
+ )
26
+
27
+
28
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    Args:
        block_args (namedtuple): BlockArgs, defined in utils.py.
        global_params (namedtuple): GlobalParam, defined in utils.py.
        image_size (tuple or list): [image_height, image_width].

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        # PyTorch's batch-norm momentum is the complement of TensorFlow's
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon

        se_ratio = self._block_args.se_ratio
        self.has_se = (se_ratio is not None) and (0 < se_ratio <= 1)
        # whether to use skip connection and drop connect
        self.id_skip = block_args.id_skip

        in_filters = self._block_args.input_filters
        expanded_filters = in_filters * self._block_args.expand_ratio
        bn_options = dict(momentum=self._bn_mom, eps=self._bn_eps)

        # Expansion phase (Inverted Bottleneck); skipped when the ratio is 1.
        # A 1x1 convolution leaves image_size unchanged.
        if self._block_args.expand_ratio != 1:
            expand_conv_cls = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = expand_conv_cls(
                in_channels=in_filters,
                out_channels=expanded_filters,
                kernel_size=1,
                bias=False,
            )
            self._bn0 = nn.BatchNorm2d(num_features=expanded_filters, **bn_options)

        # Depthwise convolution phase (groups == channels makes it depthwise)
        stride = self._block_args.stride
        depthwise_conv_cls = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = depthwise_conv_cls(
            in_channels=expanded_filters,
            out_channels=expanded_filters,
            groups=expanded_filters,
            kernel_size=self._block_args.kernel_size,
            stride=stride,
            bias=False,
        )
        self._bn1 = nn.BatchNorm2d(num_features=expanded_filters, **bn_options)
        image_size = calculate_output_image_size(image_size, stride)

        # Squeeze and Excitation layer, if desired; it operates on a 1x1 map
        if self.has_se:
            se_conv_cls = get_same_padding_conv2d(image_size=(1, 1))
            squeezed_filters = max(1, int(in_filters * se_ratio))
            self._se_reduce = se_conv_cls(
                in_channels=expanded_filters,
                out_channels=squeezed_filters,
                kernel_size=1,
            )
            self._se_expand = se_conv_cls(
                in_channels=squeezed_filters,
                out_channels=expanded_filters,
                kernel_size=1,
            )

        # Pointwise convolution phase
        out_filters = self._block_args.output_filters
        project_conv_cls = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = project_conv_cls(
            in_channels=expanded_filters,
            out_channels=out_filters,
            kernel_size=1,
            bias=False,
        )
        self._bn2 = nn.BatchNorm2d(num_features=out_filters, **bn_options)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float): Drop connect rate, between 0 and 1;
                a falsy value disables drop connect.

        Returns:
            Output of this block after processing.
        """
        args = self._block_args

        # Expansion and Depthwise Convolution
        x = inputs
        if args.expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))
        x = self._swish(self._bn1(self._depthwise_conv(x)))

        # Squeeze and Excitation
        if self.has_se:
            gate = F.adaptive_avg_pool2d(x, 1)
            gate = self._se_expand(self._swish(self._se_reduce(gate)))
            x = torch.sigmoid(gate) * x

        # Pointwise Convolution
        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect.
        # NOTE(review): the stride test only passes for a value equal to the
        # integer 1; a list-valued stride such as [1] compares unequal.
        use_skip = (
            self.id_skip
            and args.stride == 1
            and args.input_filters == args.output_filters
        )
        if not use_skip:
            return x

        # The combination of skip connection and drop connect brings about stochastic depth.
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        return x + inputs

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        if memory_efficient:
            self._swish = MemoryEfficientSwish()
        else:
            self._swish = Swish()
169
+
170
+
171
class EfficientNet(nn.Module):
    """EfficientNet backbone: a strided stem convolution followed by MBConv blocks.

    This class builds only the stem and the blocks; it does not create a
    head convolution of its own.

    Args:
        blocks_args (list[namedtuple]): BlockArgs describing each stage.
        global_params (namedtuple): GlobalParams shared by all blocks.
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), "blocks_args should be a list"
        assert len(blocks_args) > 0, "block args must be greater than 0"
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters (PyTorch momentum is the complement of TensorFlow's)
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)
        self._conv_stem = Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=False
        )
        self._bn0 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Scale the stage's filters by the width multiplier and its
            # repeat count by the depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters, self._global_params
                ),
                output_filters=round_filters(
                    block_args.output_filters, self._global_params
                ),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params),
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(
                MBConvBlock(block_args, self._global_params, image_size=image_size)
            )
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1
                )
            # Remaining repeats use stride 1, so image_size stays as it is.
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(
                    MBConvBlock(block_args, self._global_params, image_size=image_size)
                )

        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.

        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_endpoints(self, inputs):
        """Collect the feature map preceding each spatial reduction, plus the final one.

        Args:
            inputs (tensor): Input tensor fed to the stem.

        Returns:
            dict: Maps 'reduction_1', 'reduction_2', ... to feature maps. An
            entry is recorded whenever a block shrinks the height of its
            input; the last entry is the final feature map.
        """
        endpoints = dict()

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks
        block_count = len(self._blocks)
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                # scale drop connect_rate with the block's depth
                drop_connect_rate *= float(idx) / block_count
            x = block(x, drop_connect_rate=drop_connect_rate)
            if prev_x.size(2) > x.size(2):
                endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
            prev_x = x

        # Head. __init__ creates no head layers, so unconditional use of
        # _conv_head/_bn1 raised AttributeError; apply them only when some
        # subclass or caller has attached both.
        conv_head = getattr(self, "_conv_head", None)
        bn_head = getattr(self, "_bn1", None)
        if conv_head is not None and bn_head is not None:
            x = self._swish(bn_head(conv_head(x)))
        endpoints["reduction_{}".format(len(endpoints) + 1)] = x

        return endpoints

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, bias=False
            )
280
+
281
+
282
class EfficientEncoderB7(EfficientNet):
    """EfficientNet encoder that returns the outputs of four selected blocks.

    Built with width coefficient 2.0, depth coefficient 3.1, dropout rate 0.5
    and image size 600.
    """

    def __init__(self):
        blocks_args, global_params = create_block_args(
            width_coefficient=2.0,
            depth_coefficient=3.1,
            dropout_rate=0.5,
            image_size=600,
        )
        super().__init__(blocks_args, global_params)
        self._change_in_channels(3)
        # Indices of the blocks whose outputs are returned by forward
        self.block_idx = [10, 17, 37, 54]
        self.channels = [48, 80, 224, 640]

    def initial_conv(self, inputs):
        """Run the stem: convolution, batch norm and swish."""
        stem_out = self._conv_stem(inputs)
        return self._swish(self._bn0(stem_out))

    def get_blocks(self, x, H, W, block_idx):
        """Run every block and collect copies of four selected outputs.

        Args:
            x (tensor): Output of the stem.
            H: Input height (not used by this method).
            W: Input width (not used by this method).
            block_idx: Sequence whose first four entries are the indices of
                the blocks to collect.

        Returns:
            list: Cloned block outputs, in the order the blocks are executed.
        """
        features = []
        total_blocks = len(self._blocks)
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                # scale drop connect_rate with the block's depth
                drop_connect_rate *= float(idx) / total_blocks
            x = block(x, drop_connect_rate=drop_connect_rate)
            for wanted in (block_idx[0], block_idx[1], block_idx[2], block_idx[3]):
                if idx == wanted:
                    features.append(x.clone())

        return features

    def forward(self, inputs: torch.Tensor) -> List[Any]:
        """Return the selected backbone features for a 4-D input batch."""
        _, _, H, W = inputs.size()
        stem_out = self.initial_conv(inputs)  # Prepare input for the backbone
        return self.get_blocks(stem_out, H, W, block_idx=self.block_idx)
carvekit/ml/arch/tracerb7/tracer.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/Karel911/TRACER
3
+ Author: Min Seok Lee and Wooseok Shin
4
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
5
+ License: Apache License 2.0
6
+ Changes:
7
+ - Refactored code
8
+ - Removed unused code
9
+ - Added comments
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from typing import List, Optional, Tuple
16
+
17
+ from torch import Tensor
18
+
19
+ from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
20
+ from carvekit.ml.arch.tracerb7.att_modules import (
21
+ RFB_Block,
22
+ aggregation,
23
+ ObjectAttention,
24
+ )
25
+
26
+
27
class TracerDecoder(nn.Module):
    """Tracer Decoder"""

    def __init__(
        self,
        encoder: EfficientEncoderB7,
        features_channels: Optional[List[int]] = None,
        rfb_channel: Optional[List[int]] = None,
    ):
        """
        Build the decoder around ``encoder``.

        Args:
            encoder: The encoder to use; ``forward`` reads four feature maps
                from its output.
            features_channels: The channels of the backbone features at
                different stages. default: [48, 80, 224, 640]
            rfb_channel: The channels of the RFB features. default: [32, 64, 128]
        """
        super().__init__()
        rfb_widths = [32, 64, 128] if rfb_channel is None else rfb_channel
        backbone_widths = (
            [48, 80, 224, 640] if features_channels is None else features_channels
        )
        self.encoder = encoder
        self.features_channels = backbone_widths

        # Receptive Field Blocks: one per backbone feature 1..3.
        self.rfb2 = RFB_Block(backbone_widths[1], rfb_widths[0])
        self.rfb3 = RFB_Block(backbone_widths[2], rfb_widths[1])
        self.rfb4 = RFB_Block(backbone_widths[3], rfb_widths[2])

        # Multi-level aggregation of the three RFB outputs.
        self.agg = aggregation(rfb_widths)

        # Object Attention modules for backbone features 1 and 0.
        self.ObjectAttention2 = ObjectAttention(
            channel=backbone_widths[1], kernel_size=3
        )
        self.ObjectAttention1 = ObjectAttention(
            channel=backbone_widths[0], kernel_size=3
        )

    def forward(self, inputs: torch.Tensor) -> Tensor:
        """
        Forward pass of the tracer decoder.

        Args:
            inputs: Preprocessed images.

        Returns:
            A single tensor: the sigmoid of the average of three upsampled
            decoder maps.
        """
        features = self.encoder(inputs)
        rfb_low = self.rfb2(features[1])
        rfb_mid = self.rfb3(features[2])
        rfb_high = self.rfb4(features[3])

        coarse = self.agg(rfb_high, rfb_mid, rfb_low)
        coarse_up = F.interpolate(coarse, scale_factor=8, mode="bilinear")

        refined = self.ObjectAttention2(coarse, features[1])
        refined_up = F.interpolate(refined, scale_factor=8, mode="bilinear")

        refined_x2 = F.interpolate(refined, scale_factor=2, mode="bilinear")
        fine = self.ObjectAttention1(refined_x2, features[0])
        fine_up = F.interpolate(fine, scale_factor=4, mode="bilinear")

        fused = (fine_up + refined_up + coarse_up) / 3
        return torch.sigmoid(fused)
carvekit/ml/arch/u2net/__init__.py ADDED
File without changes
carvekit/ml/arch/u2net/u2net.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
3
+ Source url: https://github.com/xuebinqin/U-2-Net
4
+ License: Apache License 2.0
5
+ """
6
+ from typing import Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ import math
12
+
13
+ __all__ = ["U2NETArchitecture"]
14
+
15
+
16
+ def _upsample_like(x, size):
17
+ return nn.Upsample(size=size, mode="bilinear", align_corners=False)(x)
18
+
19
+
20
+ def _size_map(x, height):
21
+ # {height: size} for Upsample
22
+ size = list(x.shape[-2:])
23
+ sizes = {}
24
+ for h in range(1, height):
25
+ sizes[h] = size
26
+ size = [math.ceil(w / 2) for w in size]
27
+ return sizes
28
+
29
+
30
class REBNCONV(nn.Module):
    """3x3 convolution (optionally dilated) followed by batch norm and ReLU."""

    def __init__(self, in_ch=3, out_ch=3, dilate=1):
        super().__init__()
        # Padding equals the dilation so the spatial size is preserved.
        rate = 1 * dilate
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=rate, dilation=rate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU to ``x`` in that order."""
        y = self.conv_s1(x)
        y = self.bn_s1(y)
        return self.relu_s1(y)
42
+
43
+
44
class RSU(nn.Module):
    """Residual U-block: a small encoder-decoder whose output is added to its input projection."""

    def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
        super().__init__()
        self.name = name
        self.height = height
        self.dilated = dilated
        self._make_layers(height, in_ch, mid_ch, out_ch, dilated)

    def forward(self, x):
        """Return ``rebnconvin(x)`` plus the output of the inner U-structure."""
        sizes = _size_map(x, self.height)
        x = self.rebnconvin(x)
        return x + self._descend(x, 1, sizes)

    def _descend(self, x, level, sizes):
        """Evaluate the symmetric encoder-decoder recursively from ``level`` down."""
        encoder = getattr(self, f"rebnconv{level}")
        if level >= self.height:
            # Deepest level: a single convolution block, no decoder twin.
            return encoder(x)

        skip = encoder(x)
        # Non-dilated blocks pool between levels, except before the deepest one.
        pooled = not self.dilated and level < self.height - 1
        deeper_input = self.downsample(skip) if pooled else skip
        deeper = self._descend(deeper_input, level + 1, sizes)

        merged = getattr(self, f"rebnconv{level}d")(torch.cat((deeper, skip), 1))
        if self.dilated or level <= 1:
            return merged
        # Bring the result back to the size recorded for the level above.
        return _upsample_like(merged, sizes[level - 1])

    def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
        """Register the input projection, pooling and all encoder/decoder blocks."""
        self.rebnconvin = REBNCONV(in_ch, out_ch)
        self.downsample = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch)

        for level in range(2, height):
            rate = 2 ** (level - 1) if dilated else 1
            setattr(self, f"rebnconv{level}", REBNCONV(mid_ch, mid_ch, dilate=rate))
            setattr(
                self, f"rebnconv{level}d", REBNCONV(mid_ch * 2, mid_ch, dilate=rate)
            )

        bottom_rate = 2 ** (height - 1) if dilated else 2
        setattr(
            self, f"rebnconv{height}", REBNCONV(mid_ch, mid_ch, dilate=bottom_rate)
        )
92
+
93
+
94
class U2NETArchitecture(nn.Module):
    """U^2-Net: a U-shaped stack of RSU stages with side saliency outputs."""

    def __init__(self, cfg_type: Union[dict, str] = "full", out_ch: int = 1):
        """Build the network from a named or explicit stage configuration.

        Args:
            cfg_type: ``"full"`` for the built-in configuration, or a dict of
                ``{stage: [name, (height, in_ch, mid_ch, out_ch[, dilated]), side]}``.
            out_ch: Number of channels of every side map and of the fused map.

        Raises:
            ValueError: If ``cfg_type`` is an unknown name or neither a str
                nor a dict.
        """
        super().__init__()
        if isinstance(cfg_type, str):
            if cfg_type != "full":
                raise ValueError("Unknown U^2-Net architecture conf. name")
            layers_cfgs = {
                # cfgs for building RSUs and sides
                # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
                "stage1": ["En_1", (7, 3, 32, 64), -1],
                "stage2": ["En_2", (6, 64, 32, 128), -1],
                "stage3": ["En_3", (5, 128, 64, 256), -1],
                "stage4": ["En_4", (4, 256, 128, 512), -1],
                "stage5": ["En_5", (4, 512, 256, 512, True), -1],
                "stage6": ["En_6", (4, 512, 256, 512, True), 512],
                "stage5d": ["De_5", (4, 1024, 256, 512, True), 512],
                "stage4d": ["De_4", (4, 1024, 128, 256), 256],
                "stage3d": ["De_3", (5, 512, 64, 128), 128],
                "stage2d": ["De_2", (6, 256, 32, 64), 64],
                "stage1d": ["De_1", (7, 128, 16, 64), 64],
            }
        elif isinstance(cfg_type, dict):
            layers_cfgs = cfg_type
        else:
            raise ValueError("Unknown U^2-Net architecture conf. type")
        self.out_ch = out_ch
        self._make_layers(layers_cfgs)

    def forward(self, x):
        """Return ``[fused, side1, side2, ...]`` saliency maps after sigmoid.

        The recursion depth follows ``self.height`` (derived from the
        configuration in ``_make_layers``) instead of a fixed 6, so a dict
        configuration with a different number of stages is walked correctly.
        """
        sizes = _size_map(x, self.height)
        depth = self.height
        maps = []  # storage for maps

        def side(x, h):
            # side output saliency map (before sigmoid), resized to input size
            x = getattr(self, f"side{h}")(x)
            x = _upsample_like(x, sizes[1])
            maps.append(x)

        def unet(x, height=1):
            if height < depth:
                x1 = getattr(self, f"stage{height}")(x)
                x2 = unet(self.downsample(x1), height + 1)
                x = getattr(self, f"stage{height}d")(torch.cat((x2, x1), 1))
                side(x, height)
                return _upsample_like(x, sizes[height - 1]) if height > 1 else x
            # Deepest stage: no decoder twin, only a side output.
            x = getattr(self, f"stage{height}")(x)
            side(x, height)
            return _upsample_like(x, sizes[height - 1])

        def fuse():
            # Side maps were appended deepest-first; reverse so side1 leads,
            # then fuse them with a 1x1 convolution.
            maps.reverse()
            x = torch.cat(maps, 1)
            x = self.outconv(x)
            maps.insert(0, x)
            return [torch.sigmoid(x) for x in maps]

        unet(x)
        return fuse()

    def _make_layers(self, cfgs):
        """Register pooling, every RSU stage, its side layer and the fuse layer."""
        self.height = int((len(cfgs) + 1) / 2)
        self.add_module("downsample", nn.MaxPool2d(2, stride=2, ceil_mode=True))
        for k, v in cfgs.items():
            # build rsu block
            self.add_module(k, RSU(v[0], *v[1]))
            if v[2] > 0:
                # build side layer; it is named after the last character of
                # the stage name (e.g. "De_3" -> "side3")
                self.add_module(
                    f"side{v[0][-1]}", nn.Conv2d(v[2], self.out_ch, 3, padding=1)
                )
        # build fuse layer
        self.add_module(
            "outconv", nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1)
        )
carvekit/ml/files/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
# Root cache directory in the user's home; it is created when this module is imported.
carvekit_dir = Path.home() / ".cache" / "carvekit"
carvekit_dir.mkdir(parents=True, exist_ok=True)

# Path for model checkpoints below the cache root (not created here).
checkpoints_dir = carvekit_dir / "checkpoints"
carvekit/ml/files/models_loc.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ import pathlib
7
+ from carvekit.ml.files import checkpoints_dir
8
+ from carvekit.utils.download_models import downloader
9
+
10
+
11
def u2net_full_pretrained() -> pathlib.Path:
    """Resolve the U^2-Net checkpoint ``u2net.pth`` through ``downloader``.

    Returns:
        pathlib.Path to model location
    """
    checkpoint = downloader("u2net.pth")
    return checkpoint
18
+
19
+
20
def basnet_pretrained() -> pathlib.Path:
    """Resolve the BASNet checkpoint ``basnet.pth`` through ``downloader``.

    Returns:
        pathlib.Path to model location
    """
    checkpoint = downloader("basnet.pth")
    return checkpoint
27
+
28
+
29
def deeplab_pretrained() -> pathlib.Path:
    """Resolve the DeepLab checkpoint ``deeplab.pth`` through ``downloader``.

    Returns:
        pathlib.Path to model location
    """
    checkpoint = downloader("deeplab.pth")
    return checkpoint
36
+
37
+
38
def fba_pretrained() -> pathlib.Path:
    """Resolve the FBA Matting checkpoint ``fba_matting.pth`` through ``downloader``.

    Returns:
        pathlib.Path to model location
    """
    checkpoint = downloader("fba_matting.pth")
    return checkpoint
45
+
46
+
47
def tracer_b7_pretrained() -> pathlib.Path:
    """Resolve the TRACER (EfficientNet v1 b7 encoder) checkpoint ``tracer_b7.pth``.

    Returns:
        pathlib.Path to model location
    """
    checkpoint = downloader("tracer_b7.pth")
    return checkpoint
54
+
55
+
56
def tracer_hair_pretrained() -> pathlib.Path:
    """Resolve the TRACER hair-segmentation checkpoint ``tracer_hair.pth``.

    Returns:
        pathlib.Path to model location
    """
    checkpoint = downloader("tracer_hair.pth")
    return checkpoint
63
+
64
+
65
def download_all():
    """Resolve every checkpoint listed below, one after another.

    The hair-segmentation TRACER checkpoint is not part of this list.
    """
    fetchers = (
        u2net_full_pretrained,
        fba_pretrained,
        deeplab_pretrained,
        basnet_pretrained,
        tracer_b7_pretrained,
    )
    for fetch in fetchers:
        fetch()
carvekit/ml/wrap/__init__.py ADDED
File without changes
carvekit/ml/wrap/basnet.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ import pathlib
7
+ from typing import Union, List
8
+
9
+ import PIL
10
+ import numpy as np
11
+ import torch
12
+ from PIL import Image
13
+
14
+ from carvekit.ml.arch.basnet.basnet import BASNet
15
+ from carvekit.ml.files.models_loc import basnet_pretrained
16
+ from carvekit.utils.image_utils import convert_image, load_image
17
+ from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
18
+
19
+ __all__ = ["BASNET"]
20
+
21
+
22
class BASNET(BASNet):
    """BASNet model interface"""

    def __init__(
        self,
        device="cpu",
        input_image_size: Union[List[int], int] = 320,
        batch_size: int = 10,
        load_pretrained: bool = True,
        fp16: bool = False,
    ):
        """
        Initialize the BASNET model

        Args:
            device: processing device
            input_image_size: input image size; an int gives a square size,
                a list contributes its first two entries
            batch_size: the number of images that the neural network processes in one run
            load_pretrained: loading pretrained model
            fp16: use fp16 precision // not supported at this moment (the value is ignored)

        """
        super().__init__(n_channels=3, n_classes=1)
        self.device = device
        self.batch_size = batch_size
        if isinstance(input_image_size, list):
            self.input_image_size = input_image_size[:2]
        else:
            self.input_image_size = (input_image_size, input_image_size)
        self.to(device)
        if load_pretrained:
            self.load_state_dict(
                torch.load(basnet_pretrained(), map_location=self.device)
            )
        self.eval()

    def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
        """
        Transform input image to suitable data format for neural network

        Args:
            data: input image

        Returns:
            input for neural network: a float tensor of shape (1, 3, H, W)

        """
        resized = data.resize(self.input_image_size)
        # noinspection PyTypeChecker
        resized_arr = np.array(resized, dtype=np.float64)
        temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
        # Scale by the image's own maximum; an all-black image is left as zeros.
        peak = np.max(resized_arr)
        if peak != 0:
            resized_arr /= peak
        # Per-channel mean/std normalisation.
        temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
        temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
        temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
        temp_image = temp_image.transpose((2, 0, 1))
        temp_image = np.expand_dims(temp_image, 0)
        return torch.from_numpy(temp_image).type(torch.FloatTensor)

    @staticmethod
    def data_postprocessing(
        data: torch.tensor, original_image: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Transforms output data from neural network to suitable data
        format for using with other components of this framework.

        Args:
            data: output data from neural network
            original_image: input image which was used for predicted data

        Returns:
            Segmentation mask as PIL Image instance

        """
        data = data.unsqueeze(0)
        mask = data[:, 0, :, :]
        ma = torch.max(mask)  # Normalizes prediction
        mi = torch.min(mask)
        span = ma - mi
        if span > 0:
            predict = ((mask - mi) / span).squeeze()
        else:
            # A constant prediction would make min-max scaling 0/0 (NaN).
            # Use the raw values limited to [0, 1] instead.
            # NOTE(review): assumes the network output is a probability map - confirm.
            predict = mask.clamp(0, 1).squeeze()
        predict_np = predict.cpu().data.numpy() * 255
        mask = Image.fromarray(predict_np).convert("L")
        mask = mask.resize(original_image.size, resample=3)
        return mask

    def __call__(
        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
    ) -> List[PIL.Image.Image]:
        """
        Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances

        Args:
            images: input images

        Returns:
            segmentation masks as for input images, as PIL.Image.Image instances

        """
        collect_masks = []
        for image_batch in batch_generator(images, self.batch_size):
            # Use a separate name so the ``images`` argument is not rebound.
            loaded = thread_pool_processing(
                lambda x: convert_image(load_image(x)), image_batch
            )
            batches = torch.vstack(
                thread_pool_processing(self.data_preprocessing, loaded)
            )
            with torch.no_grad():
                batches = batches.to(self.device)
                # The network returns eight maps; only the first one is used.
                masks, d2, d3, d4, d5, d6, d7, d8 = super().__call__(batches)
                masks_cpu = masks.cpu()
                del d2, d3, d4, d5, d6, d7, d8, batches, masks
            masks = thread_pool_processing(
                lambda x: self.data_postprocessing(masks_cpu[x], loaded[x]),
                range(len(loaded)),
            )
            collect_masks += masks
        return collect_masks
carvekit/ml/wrap/deeplab_v3.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ import pathlib
7
+ from typing import List, Union
8
+
9
+ import PIL.Image
10
+ import torch
11
+ from PIL import Image
12
+ from torchvision import transforms
13
+ from torchvision.models.segmentation import deeplabv3_resnet101
14
+ from carvekit.ml.files.models_loc import deeplab_pretrained
15
+ from carvekit.utils.image_utils import convert_image, load_image
16
+ from carvekit.utils.models_utils import get_precision_autocast, cast_network
17
+ from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
18
+
19
+ __all__ = ["DeepLabV3"]
20
+
21
+
22
class DeepLabV3:
    """DeepLabV3 (ResNet-101) model interface"""

    def __init__(
        self,
        device="cpu",
        batch_size: int = 10,
        input_image_size: Union[List[int], int] = 1024,
        load_pretrained: bool = True,
        fp16: bool = False,
    ):
        """
        Initialize the DeepLabV3 model

        Args:
            device: processing device
            input_image_size: input image size; an int gives a square bound,
                a list contributes its first two entries
            batch_size: the number of images that the neural network processes in one run
            load_pretrained: loading pretrained model
            fp16: use half precision

        """
        self.device = device
        self.batch_size = batch_size
        self.network = deeplabv3_resnet101(
            pretrained=False, pretrained_backbone=False, aux_loss=True
        )
        self.network.to(self.device)
        if load_pretrained:
            self.network.load_state_dict(
                torch.load(deeplab_pretrained(), map_location=self.device)
            )
        if isinstance(input_image_size, list):
            self.input_image_size = input_image_size[:2]
        else:
            self.input_image_size = (input_image_size, input_image_size)
        self.network.eval()
        self.fp16 = fp16
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def to(self, device: str):
        """
        Moves neural network to specified processing device

        Args:
            device (:class:`torch.device`): the desired device.
        Returns:
            None

        """
        self.network.to(device)
        # Keep the recorded device in sync: ``__call__`` sends its inputs to
        # ``self.device``, so a stale value would cause a device mismatch.
        self.device = device

    def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
        """
        Transform input image to suitable data format for neural network

        Args:
            data: input image

        Returns:
            input for neural network

        """
        # ``thumbnail`` works in place, so operate on a copy of the input.
        copy = data.copy()
        copy.thumbnail(self.input_image_size, resample=3)
        return self.transform(copy)

    @staticmethod
    def data_postprocessing(
        data: torch.tensor, original_image: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Transforms output data from neural network to suitable data
        format for using with other components of this framework.

        Args:
            data: output data from neural network (per-pixel class indices)
            original_image: input image which was used for predicted data

        Returns:
            Segmentation mask as PIL Image instance: 255 where the class
            index is non-zero, 0 elsewhere

        """
        # ``data`` is a uint8 class-index map; multiplying it by 255 directly
        # wraps modulo 256 (e.g. class 15 -> 241). Binarise first so every
        # non-zero class becomes exactly 255.
        foreground = data.ne(0).to(torch.uint8) * 255
        return (
            Image.fromarray(foreground.numpy())
            .convert("L")
            .resize(original_image.size)
        )

    def __call__(
        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
    ) -> List[PIL.Image.Image]:
        """
        Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances

        Args:
            images: input images

        Returns:
            segmentation masks as for input images, as PIL.Image.Image instances

        """
        collect_masks = []
        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
        with autocast:
            cast_network(self.network, dtype)
            for image_batch in batch_generator(images, self.batch_size):
                # Use a separate name so the ``images`` argument is not rebound.
                loaded = thread_pool_processing(
                    lambda x: convert_image(load_image(x)), image_batch
                )
                batches = thread_pool_processing(self.data_preprocessing, loaded)
                with torch.no_grad():
                    # Images are run one by one: ``thumbnail`` keeps aspect
                    # ratio, so tensors in a batch may differ in size.
                    masks = [
                        self.network(i.to(self.device).unsqueeze(0))["out"][0]
                        .argmax(0)
                        .byte()
                        .cpu()
                        for i in batches
                    ]
                    del batches
                masks = thread_pool_processing(
                    lambda x: self.data_postprocessing(masks[x], loaded[x]),
                    range(len(loaded)),
                )
                collect_masks += masks
        return collect_masks
carvekit/ml/wrap/fba_matting.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ import pathlib
7
+ from typing import Union, List, Tuple
8
+
9
+ import PIL
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+ from PIL import Image
14
+
15
+ from carvekit.ml.arch.fba_matting.models import FBA
16
+ from carvekit.ml.arch.fba_matting.transforms import (
17
+ trimap_transform,
18
+ groupnorm_normalise_image,
19
+ )
20
+ from carvekit.ml.files.models_loc import fba_pretrained
21
+ from carvekit.utils.image_utils import convert_image, load_image
22
+ from carvekit.utils.models_utils import get_precision_autocast, cast_network
23
+ from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
24
+
25
+ __all__ = ["FBAMatting"]
26
+
27
+
28
class FBAMatting(FBA):
    """
    FBA Matting Neural Network to improve edges on image.
    """

    def __init__(
        self,
        device="cpu",
        input_tensor_size: Union[List[int], int] = 2048,
        batch_size: int = 2,
        encoder="resnet50_GN_WS",
        load_pretrained: bool = True,
        fp16: bool = False,
    ):
        """
        Initialize the FBAMatting model

        Args:
            device: processing device
            input_tensor_size: input image size; an int gives a square size,
                a list contributes its first two entries
            batch_size: the number of images that the neural network processes in one run
            encoder: neural network encoder head
            load_pretrained: loading pretrained model
            fp16: use half precision

        """
        super().__init__(encoder=encoder)
        self.fp16 = fp16
        self.device = device
        self.batch_size = batch_size
        if isinstance(input_tensor_size, list):
            self.input_image_size = input_tensor_size[:2]
        else:
            self.input_image_size = (input_tensor_size, input_tensor_size)
        self.to(device)
        if load_pretrained:
            self.load_state_dict(torch.load(fba_pretrained(), map_location=self.device))
        self.eval()

    def data_preprocessing(
        self, data: Union[PIL.Image.Image, np.ndarray]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """
        Transform input image to suitable data format for neural network

        Args:
            data: input image in mode "RGB" (picture) or "L" (trimap)

        Returns:
            input for neural network: the scaled tensor and its transformed
            counterpart (normalised image for "RGB", transformed trimap for "L")

        Raises:
            ValueError: if the image mode is neither "RGB" nor "L"

        """
        resized = data.copy()
        if self.batch_size == 1:
            # A single image keeps its aspect ratio.
            resized.thumbnail(self.input_image_size, resample=3)
        else:
            # Batched images must share one size so they can be stacked.
            resized = resized.resize(self.input_image_size, resample=3)
        # noinspection PyTypeChecker
        image = np.array(resized, dtype=np.float64)
        image = image / 255.0  # Normalize image to [0, 1] values range
        if resized.mode == "RGB":
            image = image[:, :, ::-1]  # reverse the channel order
        elif resized.mode == "L":
            image2 = np.copy(image)
            h, w = image2.shape
            image = np.zeros((h, w, 2))  # Transform trimap to binary data format
            image[image2 == 1, 1] = 1
            image[image2 == 0, 0] = 1
        else:
            raise ValueError("Incorrect color mode for image")
        h, w = image.shape[:2]  # Scale input mlt to 8
        h1 = int(np.ceil(1.0 * h / 8) * 8)
        w1 = int(np.ceil(1.0 * w / 8) * 8)
        x_scale = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_LANCZOS4)
        image_tensor = torch.from_numpy(x_scale).permute(2, 0, 1)[None, :, :, :].float()
        if resized.mode == "RGB":
            return image_tensor, groupnorm_normalise_image(
                image_tensor.clone(), format="nchw"
            )
        return (
            image_tensor,
            torch.from_numpy(trimap_transform(x_scale))
            .permute(2, 0, 1)[None, :, :, :]
            .float(),
        )

    @staticmethod
    def data_postprocessing(
        data: torch.tensor, trimap: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Transforms output data from neural network to suitable data
        format for using with other components of this framework.

        Args:
            data: output data from neural network
            trimap: Map with the area we need to refine

        Returns:
            Segmentation mask as PIL Image instance

        Raises:
            ValueError: if ``trimap`` is not in mode "L"

        """
        if trimap.mode != "L":
            raise ValueError("Incorrect color mode for trimap")
        pred = data.numpy().transpose((1, 2, 0))
        # The third positional parameter of cv2.resize is ``dst``, so the
        # interpolation flag has to be passed by keyword to take effect.
        pred = cv2.resize(pred, trimap.size, interpolation=cv2.INTER_LANCZOS4)[:, :, 0]
        # noinspection PyTypeChecker
        # Clean mask by removing all false predictions outside trimap and already known area
        trimap_arr = np.array(trimap.copy())
        pred[trimap_arr[:, :] == 0] = 0
        # NOTE: pixels with trimap == 255 are not forced to 1 (that assignment is disabled).
        pred[pred < 0.3] = 0
        return Image.fromarray(pred * 255).convert("L")

    def __call__(
        self,
        images: List[Union[str, pathlib.Path, PIL.Image.Image]],
        trimaps: List[Union[str, pathlib.Path, PIL.Image.Image]],
    ) -> List[PIL.Image.Image]:
        """
        Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances

        Args:
            images: input images
            trimaps: Maps with the areas we need to refine

        Returns:
            segmentation masks as for input images, as PIL.Image.Image instances

        Raises:
            ValueError: if ``images`` and ``trimaps`` differ in length

        """

        if len(images) != len(trimaps):
            raise ValueError(
                "Len of specified arrays of images and trimaps should be equal!"
            )

        collect_masks = []
        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
        with autocast:
            cast_network(self, dtype)
            # Batch over indices so images and trimaps stay paired.
            for idx_batch in batch_generator(range(len(images)), self.batch_size):
                inpt_images = thread_pool_processing(
                    lambda x: convert_image(load_image(images[x])), idx_batch
                )

                inpt_trimaps = thread_pool_processing(
                    lambda x: convert_image(load_image(trimaps[x]), mode="L"), idx_batch
                )

                inpt_img_batches = thread_pool_processing(
                    self.data_preprocessing, inpt_images
                )
                inpt_trimaps_batches = thread_pool_processing(
                    self.data_preprocessing, inpt_trimaps
                )

                # Each preprocessing result is a (tensor, transformed tensor) pair.
                inpt_img_batches_transformed = torch.vstack(
                    [i[1] for i in inpt_img_batches]
                )
                inpt_img_batches = torch.vstack([i[0] for i in inpt_img_batches])

                inpt_trimaps_transformed = torch.vstack(
                    [i[1] for i in inpt_trimaps_batches]
                )
                inpt_trimaps_batches = torch.vstack(
                    [i[0] for i in inpt_trimaps_batches]
                )

                with torch.no_grad():
                    inpt_img_batches = inpt_img_batches.to(self.device)
                    inpt_trimaps_batches = inpt_trimaps_batches.to(self.device)
                    inpt_img_batches_transformed = inpt_img_batches_transformed.to(
                        self.device
                    )
                    inpt_trimaps_transformed = inpt_trimaps_transformed.to(self.device)

                    output = super().__call__(
                        inpt_img_batches,
                        inpt_trimaps_batches,
                        inpt_img_batches_transformed,
                        inpt_trimaps_transformed,
                    )
                    output_cpu = output.cpu()
                    del (
                        inpt_img_batches,
                        inpt_trimaps_batches,
                        inpt_img_batches_transformed,
                        inpt_trimaps_transformed,
                        output,
                    )
                masks = thread_pool_processing(
                    lambda x: self.data_postprocessing(output_cpu[x], inpt_trimaps[x]),
                    range(len(inpt_images)),
                )
                collect_masks += masks
        return collect_masks
carvekit/ml/wrap/tracer_b7.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ import pathlib
7
+ import warnings
8
+ from typing import List, Union
9
+ import PIL.Image
10
+ import numpy as np
11
+ import torch
12
+ import torchvision.transforms as transforms
13
+ from PIL import Image
14
+
15
+ from carvekit.ml.arch.tracerb7.tracer import TracerDecoder
16
+ from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
17
+ from carvekit.ml.files.models_loc import tracer_b7_pretrained, tracer_hair_pretrained
18
+ from carvekit.utils.models_utils import get_precision_autocast, cast_network
19
+ from carvekit.utils.image_utils import load_image, convert_image
20
+ from carvekit.utils.pool_utils import thread_pool_processing, batch_generator
21
+
22
+ __all__ = ["TracerUniversalB7"]
23
+
24
+
25
class TracerUniversalB7(TracerDecoder):
    """TRACER B7 model interface"""

    def __init__(
        self,
        device="cpu",
        input_image_size: Union[List[int], int] = 640,
        batch_size: int = 4,
        load_pretrained: bool = True,
        fp16: bool = False,
        model_path: Union[str, pathlib.Path] = None,
    ):
        """
        Initialize the TRACER B7 model

        Args:
            device: processing device
            input_image_size: input image size; an int gives a square size,
                a list contributes its first two entries
            batch_size: the number of images that the neural network processes in one run
            load_pretrained: loading pretrained model
            fp16: use fp16 precision
            model_path: checkpoint to load; when None the default TRACER B7
                checkpoint is resolved, but only if ``load_pretrained`` is set

        """
        super().__init__(
            encoder=EfficientEncoderB7(),
            rfb_channel=[32, 64, 128],
            features_channels=[48, 80, 224, 640],
        )

        self.fp16 = fp16
        self.device = device
        self.batch_size = batch_size
        if isinstance(input_image_size, list):
            self.input_image_size = input_image_size[:2]
        else:
            self.input_image_size = (input_image_size, input_image_size)

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize(self.input_image_size),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        self.to(device)
        if load_pretrained:
            # Resolve the default checkpoint only when it is going to be
            # loaded, so ``load_pretrained=False`` never triggers a download.
            if model_path is None:
                model_path = tracer_b7_pretrained()
            # TODO remove edge detector from weights. It doesn't work well with this model!
            self.load_state_dict(
                torch.load(model_path, map_location=self.device), strict=False
            )
        self.eval()

    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
        """
        Transform input image to suitable data format for neural network

        Args:
            data: input image

        Returns:
            input for neural network, with a leading batch dimension of 1

        """

        return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor)

    @staticmethod
    def data_postprocessing(
        data: torch.tensor, original_image: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Transforms output data from neural network to suitable data
        format for using with other components of this framework.

        Args:
            data: output data from neural network
            original_image: input image which was used for predicted data

        Returns:
            Segmentation mask as PIL Image instance

        """
        output = (data.type(torch.FloatTensor).detach().cpu().numpy() * 255.0).astype(
            np.uint8
        )
        output = output.squeeze(0)
        mask = Image.fromarray(output).convert("L")
        mask = mask.resize(original_image.size, resample=Image.BILINEAR)
        return mask

    def __call__(
        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
    ) -> List[PIL.Image.Image]:
        """
        Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances

        Args:
            images: input images

        Returns:
            segmentation masks as for input images, as PIL.Image.Image instances

        """
        collect_masks = []
        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
        with autocast:
            cast_network(self, dtype)
            for image_batch in batch_generator(images, self.batch_size):
                # Use a separate name so the ``images`` argument is not rebound.
                loaded = thread_pool_processing(
                    lambda x: convert_image(load_image(x)), image_batch
                )
                batches = torch.vstack(
                    thread_pool_processing(self.data_preprocessing, loaded)
                )
                with torch.no_grad():
                    batches = batches.to(self.device)
                    # Zero-argument super() starts the lookup after this
                    # class rather than after TracerDecoder.
                    masks = super().__call__(batches)
                    masks_cpu = masks.cpu()
                    del batches, masks
                masks = thread_pool_processing(
                    lambda x: self.data_postprocessing(masks_cpu[x], loaded[x]),
                    range(len(loaded)),
                )
                collect_masks += masks

        return collect_masks
154
+
155
+
156
class TracerHair(TracerUniversalB7):
    """TRACER HAIR model interface"""

    def __init__(
        self,
        device="cpu",
        input_image_size: Union[List[int], int] = 640,
        batch_size: int = 4,
        load_pretrained: bool = True,
        fp16: bool = False,
        model_path: Union[str, pathlib.Path] = None,
    ):
        """
        Initialize the TRACER hair-segmentation model

        Args:
            device: processing device
            input_image_size: input image size
            batch_size: the number of images that the neural network processes in one run
            load_pretrained: loading pretrained model
            fp16: use fp16 precision
            model_path: checkpoint to load; when None the hair checkpoint is
                resolved, but only if ``load_pretrained`` is set

        """
        # Resolve the hair checkpoint only when weights are actually loaded,
        # so ``load_pretrained=False`` does not ask for it.
        if load_pretrained and model_path is None:
            model_path = tracer_hair_pretrained()
        warnings.warn("TracerHair has not public model yet. Don't use it!", UserWarning)
        super().__init__(
            device=device,
            input_image_size=input_image_size,
            batch_size=batch_size,
            load_pretrained=load_pretrained,
            fp16=fp16,
            model_path=model_path,
        )
carvekit/ml/wrap/u2net.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ import pathlib
7
+ from typing import List, Union
8
+ import PIL.Image
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image
12
+
13
+ from carvekit.ml.arch.u2net.u2net import U2NETArchitecture
14
+ from carvekit.ml.files.models_loc import u2net_full_pretrained
15
+ from carvekit.utils.image_utils import load_image, convert_image
16
+ from carvekit.utils.pool_utils import thread_pool_processing, batch_generator
17
+
18
+ __all__ = ["U2NET"]
19
+
20
+
21
class U2NET(U2NETArchitecture):
    """U^2-Net model interface"""

    def __init__(
        self,
        layers_cfg="full",
        device="cpu",
        input_image_size: Union[List[int], int] = 320,
        batch_size: int = 10,
        load_pretrained: bool = True,
        fp16: bool = False,
    ):
        """
        Initialize the U2NET model

        Args:
            layers_cfg: neural network layers configuration
            device: processing device
            input_image_size: input image size, either a single int (square input)
                or a list/tuple whose first two items are used
            batch_size: the number of images that the neural network processes in one run
            load_pretrained: loading pretrained model
            fp16: use fp16 precision // not supported at this moment,
                the value is accepted but not used by this class.

        """
        super(U2NET, self).__init__(cfg_type=layers_cfg, out_ch=1)
        self.device = device
        self.batch_size = batch_size
        if isinstance(input_image_size, (list, tuple)):
            # Tuples used to fall into the scalar branch and produced a nested size.
            self.input_image_size = tuple(input_image_size[:2])
        else:
            self.input_image_size = (input_image_size, input_image_size)
        self.to(device)
        if load_pretrained:
            self.load_state_dict(
                torch.load(u2net_full_pretrained(), map_location=self.device)
            )
        self.eval()

    def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:
        """
        Transform input image to suitable data format for neural network

        Args:
            data: input image

        Returns:
            input for neural network, a float tensor with a leading batch axis of size 1

        """
        # resample=3 is PIL's bicubic filter.
        resized = data.resize(self.input_image_size, resample=3)
        # noinspection PyTypeChecker
        resized_arr = np.array(resized, dtype=float)
        temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
        peak = np.max(resized_arr)
        if peak != 0:
            # Scale by the image's own maximum; an all-black image is left as zeros.
            resized_arr /= peak
        # Per-channel mean/std normalisation.
        temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
        temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
        temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
        temp_image = temp_image.transpose((2, 0, 1))  # HWC -> CHW
        temp_image = np.expand_dims(temp_image, 0)
        return torch.from_numpy(temp_image).type(torch.FloatTensor)

    @staticmethod
    def data_postprocessing(
        data: torch.tensor, original_image: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Transforms output data from neural network to suitable data
        format for using with other components of this framework.

        Args:
            data: output data from neural network
            original_image: input image which was used for predicted data

        Returns:
            Segmentation mask as PIL Image instance

        """
        data = data.unsqueeze(0)
        mask = data[:, 0, :, :]
        ma = torch.max(mask)  # Normalizes prediction
        mi = torch.min(mask)
        value_range = ma - mi
        if value_range > 0:
            predict = (mask - mi) / value_range
        else:
            # A constant prediction cannot be min-max normalised (0 / 0 -> NaN).
            # Keep the raw values instead, limited to the [0, 1] range
            # (assumes the network emits probabilities - TODO confirm).
            predict = mask.clamp(0, 1)
        predict_np = predict.squeeze().detach().cpu().numpy() * 255
        mask = Image.fromarray(predict_np).convert("L")
        mask = mask.resize(original_image.size, resample=3)
        return mask

    def __call__(
        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
    ) -> List[PIL.Image.Image]:
        """
        Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances

        Args:
            images: input images

        Returns:
            segmentation masks as for input images, as PIL.Image.Image instances

        """
        collect_masks = []
        for image_batch in batch_generator(images, self.batch_size):
            # Use a separate name so the `images` argument is not rebound mid-loop.
            loaded = thread_pool_processing(
                lambda x: convert_image(load_image(x)), image_batch
            )
            batches = torch.vstack(
                thread_pool_processing(self.data_preprocessing, loaded)
            )
            with torch.no_grad():
                batches = batches.to(self.device)
                # Only the first output is used; the other six are discarded.
                masks, d2, d3, d4, d5, d6, d7 = super(U2NET, self).__call__(batches)
                masks_cpu = masks.cpu()
                del d2, d3, d4, d5, d6, d7, batches, masks
            collect_masks += thread_pool_processing(
                lambda x: self.data_postprocessing(masks_cpu[x], loaded[x]),
                range(len(loaded)),
            )
        return collect_masks
carvekit/pipelines/__init__.py ADDED
File without changes
carvekit/pipelines/postprocessing.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ from carvekit.ml.wrap.fba_matting import FBAMatting
7
+ from typing import Union, List
8
+ from PIL import Image
9
+ from pathlib import Path
10
+ from carvekit.trimap.cv_gen import CV2TrimapGenerator
11
+ from carvekit.trimap.generator import TrimapGenerator
12
+ from carvekit.utils.mask_utils import apply_mask
13
+ from carvekit.utils.pool_utils import thread_pool_processing
14
+ from carvekit.utils.image_utils import load_image, convert_image
15
+
16
+ __all__ = ["MattingMethod"]
17
+
18
+
19
class MattingMethod:
    """
    Refines the edges of an object mask with a matting neural network.

    A trimap generator turns every (image, mask) pair into a trimap, the
    matting network turns images and trimaps into alpha masks, and the alpha
    masks are finally applied to the images.
    """

    def __init__(
        self,
        matting_module: Union[FBAMatting],
        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
        device="cpu",
    ):
        """
        Initializes Matting Method class.

        Args:
            matting_module: Initialized matting neural network class
            trimap_generator: Initialized trimap generator class
            device: Processing device used for applying mask to image
        """
        self.trimap_generator = trimap_generator
        self.matting_module = matting_module
        self.device = device

    def __call__(
        self,
        images: List[Union[str, Path, Image.Image]],
        masks: List[Union[str, Path, Image.Image]],
    ):
        """
        Builds trimaps, runs the matting module and applies the resulting alpha masks.

        Args:
            images: list of images
            masks: list of masks, one per image

        Returns:
            list of images with the refined mask applied

        Raises:
            ValueError: if the two lists differ in length
        """
        if len(images) != len(masks):
            raise ValueError("Images and Masks lists should have same length!")

        rgb_images = thread_pool_processing(
            lambda item: convert_image(load_image(item)), images
        )
        gray_masks = thread_pool_processing(
            lambda item: convert_image(load_image(item), mode="L"), masks
        )
        indices = range(len(rgb_images))
        trimaps = thread_pool_processing(
            lambda idx: self.trimap_generator(
                original_image=rgb_images[idx], mask=gray_masks[idx]
            ),
            indices,
        )
        alpha = self.matting_module(images=rgb_images, trimaps=trimaps)
        return [
            apply_mask(image=rgb_images[idx], mask=alpha[idx], device=self.device)
            for idx in indices
        ]
carvekit/pipelines/preprocessing.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ from pathlib import Path
7
+ from typing import Union, List
8
+
9
+ from PIL import Image
10
+
11
+ __all__ = ["PreprocessingStub"]
12
+
13
+
14
class PreprocessingStub:
    """Placeholder preprocessing step that performs no preprocessing of its own."""

    def __call__(self, interface, images: List[Union[str, Path, Image.Image]]):
        """
        Hands the images straight to interface.segmentation_pipeline().

        Args:
            interface: Interface instance
            images: list of images

        Returns:
            whatever interface.segmentation_pipeline(images=images) returns
        """
        segmented = interface.segmentation_pipeline(images=images)
        return segmented
carvekit/trimap/__init__.py ADDED
File without changes
carvekit/trimap/add_ops.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+
11
def prob_filter(mask: Image.Image, prob_threshold=231) -> Image.Image:
    """
    Binarizes the mask: pixels above the threshold become 255, all others 0.

    Args:
        prob_threshold: Threshold of probability for mark area as background.
        mask: Predicted object mask

    Raises:
        ValueError if mask has wrong color mode

    Returns:
        Filtered mask as an "L" mode image.
    """
    if mask.mode != "L":
        raise ValueError("Input mask has wrong color mode.")
    # noinspection PyTypeChecker
    pixels = np.array(mask)
    # Two in-place passes, in this order: first lift, then clear.
    np.putmask(pixels, pixels > prob_threshold, 255)
    np.putmask(pixels, pixels <= prob_threshold, 0)
    return Image.fromarray(pixels).convert("L")
32
+
33
+
34
def prob_as_unknown_area(
    trimap: Image.Image, mask: Image.Image, prob_threshold=255
) -> Image.Image:
    """
    Marks every uncertain mask pixel as unknown region (127) in the trimap.

    A mask pixel counts as uncertain when it is non-zero and not above
    prob_threshold.

    Args:
        prob_threshold: Threshold of probability for mark area as unknown.
        trimap: Generated trimap.
        mask: Predicted object mask

    Raises:
        ValueError if mask or trimap has wrong color mode

    Returns:
        Generated trimap for image.
    """
    if not (mask.mode == "L" and trimap.mode == "L"):
        raise ValueError("Input mask has wrong color mode.")
    # noinspection PyTypeChecker
    mask_pixels = np.array(mask)
    # noinspection PyTypeChecker
    trimap_pixels = np.array(trimap)
    uncertain = np.logical_and(mask_pixels <= prob_threshold, mask_pixels > 0)
    trimap_pixels[uncertain] = 127
    return Image.fromarray(trimap_pixels).convert("L")
59
+
60
+
61
def post_erosion(trimap: Image.Image, erosion_iters=1) -> Image.Image:
    """
    Erodes the known (non-127) part of the trimap and marks the pixels
    removed by the erosion as unknown region (127).

    Args:
        erosion_iters: The number of iterations of erosion that
        the object's mask will be subjected to before forming an unknown area
        trimap: Generated trimap.

    Raises:
        ValueError if trimap has wrong color mode

    Returns:
        Generated trimap for image.
    """
    if trimap.mode != "L":
        raise ValueError("Input mask has wrong color mode.")
    # noinspection PyTypeChecker
    trimap_pixels = np.array(trimap)
    if not erosion_iters > 0:
        # Nothing to erode: hand back an unchanged copy.
        return Image.fromarray(trimap_pixels.copy()).convert("L")

    # Work on the trimap with the existing unknown area blanked out.
    known = trimap_pixels.copy()
    known[known == 127] = 0

    kernel = np.ones((3, 3), np.uint8)
    eroded = cv2.erode(known, kernel, iterations=erosion_iters)
    eroded = np.where(eroded == 0, 0, known)
    # Pixels that were known before the erosion but vanished become unknown.
    lost = np.logical_and(eroded == 0, known > 0)
    trimap_pixels[lost] = 127
    return Image.fromarray(trimap_pixels.copy()).convert("L")
carvekit/trimap/cv_gen.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ import PIL.Image
7
+ import cv2
8
+ import numpy as np
9
+
10
+
11
class CV2TrimapGenerator:
    """Trimap generator built on cv2 erosion and dilation of the object mask."""

    def __init__(self, kernel_size: int = 30, erosion_iters: int = 1):
        """
        Initialize a new CV2TrimapGenerator instance

        Args:
            kernel_size: The size of the offset from the object mask
                in pixels when an unknown area is detected in the trimap
            erosion_iters: The number of iterations of erosion that
                the object's mask will be subjected to before forming an unknown area
        """
        self.erosion_iters = erosion_iters
        self.kernel_size = kernel_size

    def __call__(
        self, original_image: PIL.Image.Image, mask: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Generates trimap based on predicted object mask to refine object mask borders.
        Based on cv2 erosion algorithm.

        Args:
            original_image: Original image (only its size is checked)
            mask: Predicted object mask

        Raises:
            ValueError: if the mask is not in "L" mode or its size differs from the image

        Returns:
            Generated trimap for image.
        """
        if mask.mode != "L":
            raise ValueError("Input mask has wrong color mode.")
        if mask.size != original_image.size:
            raise ValueError("Sizes of input image and predicted mask doesn't equal")
        # noinspection PyTypeChecker
        mask_pixels = np.array(mask)
        side = 2 * self.kernel_size + 1
        dilate_kernel = np.ones((side, side), np.uint8)

        if self.erosion_iters > 0:
            eroded = cv2.erode(
                mask_pixels, np.ones((3, 3), np.uint8), iterations=self.erosion_iters
            )
            eroded = np.where(eroded == 0, 0, mask_pixels)
        else:
            eroded = mask_pixels.copy()

        dilated = cv2.dilate(eroded, dilate_kernel, iterations=1)

        # Fully white dilated pixels become the gray (127) band.
        dilated = np.where(dilated == 255, 127, dilated)
        # Pixels bright in the eroded mask get the temporary marker 200.
        trimap = np.where(eroded > 127, 200, dilated)

        # Clean-up: everything that is neither gray nor marked goes to 0.
        trimap = np.where(trimap < 127, 0, trimap)
        trimap = np.where(trimap > 200, 0, trimap)
        # Turn the temporary marker into white.
        trimap = np.where(trimap == 200, 255, trimap)

        return PIL.Image.fromarray(trimap).convert("L")
carvekit/trimap/generator.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ from PIL import Image
7
+ from carvekit.trimap.cv_gen import CV2TrimapGenerator
8
+ from carvekit.trimap.add_ops import prob_filter, prob_as_unknown_area, post_erosion
9
+
10
+
11
class TrimapGenerator(CV2TrimapGenerator):
    """CV2 trimap generator extended with probability filters and a final erosion."""

    def __init__(
        self, prob_threshold: int = 231, kernel_size: int = 30, erosion_iters: int = 5
    ):
        """
        Initialize a TrimapGenerator instance

        Args:
            prob_threshold: Probability threshold at which the
                prob_filter and prob_as_unknown_area operations will be applied
            kernel_size: The size of the offset from the object mask
                in pixels when an unknown area is detected in the trimap
            erosion_iters: The number of iterations of erosion that
                the object's mask will be subjected to before forming an unknown area
        """
        # The parent's own erosion is switched off; erosion runs in post_erosion instead.
        super().__init__(kernel_size, erosion_iters=0)
        self.__erosion_iters = erosion_iters
        self.prob_threshold = prob_threshold

    def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Image:
        """
        Generates trimap based on predicted object mask to refine object mask borders.
        Based on cv2 erosion algorithm and additional prob. filters.

        Args:
            original_image: Original image
            mask: Predicted object mask

        Returns:
            Generated trimap for image.
        """
        binary_mask = prob_filter(mask=mask, prob_threshold=self.prob_threshold)
        base_trimap = super().__call__(original_image, binary_mask)
        with_unknown = prob_as_unknown_area(
            trimap=base_trimap, mask=mask, prob_threshold=self.prob_threshold
        )
        return post_erosion(with_unknown, self.__erosion_iters)
carvekit/utils/__init__.py ADDED
File without changes
carvekit/utils/download_models.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ import hashlib
7
+ import os
8
+ import warnings
9
+ from abc import ABCMeta, abstractmethod, ABC
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ import carvekit
14
+ from carvekit.ml.files import checkpoints_dir
15
+
16
+ import requests
17
+ import tqdm
18
+
19
+ requests = requests.Session()
20
+ requests.headers.update({"User-Agent": f"Carvekit/{carvekit.version}"})
21
+
22
+ MODELS_URLS = {
23
+ "basnet.pth": {
24
+ "repository": "Carve/basnet-universal",
25
+ "revision": "870becbdb364fda6d8fdb2c10b072542f8d08701",
26
+ "filename": "basnet.pth",
27
+ },
28
+ "deeplab.pth": {
29
+ "repository": "Carve/deeplabv3-resnet101",
30
+ "revision": "d504005392fc877565afdf58aad0cd524682d2b0",
31
+ "filename": "deeplab.pth",
32
+ },
33
+ "fba_matting.pth": {
34
+ "repository": "Carve/fba",
35
+ "revision": "a5d3457df0fb9c88ea19ed700d409756ca2069d1",
36
+ "filename": "fba_matting.pth",
37
+ },
38
+ "u2net.pth": {
39
+ "repository": "Carve/u2net-universal",
40
+ "revision": "10305d785481cf4b2eee1d447c39cd6e5f43d74b",
41
+ "filename": "full_weights.pth",
42
+ },
43
+ "tracer_b7.pth": {
44
+ "repository": "Carve/tracer_b7",
45
+ "revision": "d8a8fd9e7b3fa0d2f1506fe7242966b34381e9c5",
46
+ "filename": "tracer_b7.pth",
47
+ },
48
+ "tracer_hair.pth": {
49
+ "repository": "Carve/tracer_b7",
50
+ "revision": "d8a8fd9e7b3fa0d2f1506fe7242966b34381e9c5",
51
+ "filename": "tracer_b7.pth", # TODO don't forget change this link!!
52
+ },
53
+ }
54
+
55
+ MODELS_CHECKSUMS = {
56
+ "basnet.pth": "e409cb709f4abca87cb11bd44a9ad3f909044a917977ab65244b4c94dd33"
57
+ "8b1a37755c4253d7cb54526b7763622a094d7b676d34b5e6886689256754e5a5e6ad",
58
+ "deeplab.pth": "9c5a1795bc8baa267200a44b49ac544a1ba2687d210f63777e4bd715387324469a59b072f8a28"
59
+ "9cc471c637b367932177e5b312e8ea6351c1763d9ff44b4857c",
60
+ "fba_matting.pth": "890906ec94c1bfd2ad08707a63e4ccb0955d7f5d25e32853950c24c78"
61
+ "4cbad2e59be277999defc3754905d0f15aa75702cdead3cfe669ff72f08811c52971613",
62
+ "u2net.pth": "16f8125e2fedd8c85db0e001ee15338b4aa2fda77bab8ba70c25e"
63
+ "bea1533fda5ee70a909b934a9bd495b432cef89d629f00a07858a517742476fa8b346de24f7",
64
+ "tracer_b7.pth": "c439c5c12d4d43d5f9be9ec61e68b2e54658a541bccac2577ef5a54fb252b6e8415d41f7e"
65
+ "c2487033d0c02b4dd08367958e4e62091318111c519f93e2632be7b",
66
+ "tracer_hair.pth": "5c2fb9973fc42fa6208920ffa9ac233cc2ea9f770b24b4a96969d3449aed7ac89e6d37e"
67
+ "e486a13e63be5499f2df6ccef1109e9e8797d1326207ac89b2f39a7cf",
68
+ }
69
+
70
+
71
def sha512_checksum_calc(file: Path) -> str:
    """
    Computes the SHA512 hex digest of a file, reading it in 4096-byte blocks.

    Args:
        file: Path to the file

    Returns:
        SHA512 hash digest of a file.
    """
    hasher = hashlib.sha512()
    with file.open("rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                break
            hasher.update(block)
    return hasher.hexdigest()
86
+
87
+
88
class CachedDownloader(metaclass=ABCMeta):
    """
    Base class for model downloaders with an optional fallback downloader.

    Subclasses provide `name`, `fallback_downloader` and `download_model_base`.
    The metaclass is declared with the Python 3 `metaclass=` keyword; the
    former `__metaclass__` attribute was ignored by Python 3, so the abstract
    members were never enforced.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable downloader name, used in warning messages."""
        return self.__class__.__name__

    @property
    @abstractmethod
    def fallback_downloader(self) -> Optional["CachedDownloader"]:
        """Downloader to try when this one fails, or None."""

    def download_model(self, file_name: str) -> Path:
        """
        Downloads the model, delegating to the fallback downloader on failure.

        Args:
            file_name: name of the model file

        Returns:
            Path returned by download_model_base (or by the fallback downloader).

        Raises:
            Exception: the original error, when no fallback downloader is available.
        """
        try:
            return self.download_model_base(file_name)
        except Exception:
            # Only ordinary errors trigger the fallback; KeyboardInterrupt and
            # SystemExit must propagate instead of starting another download.
            fallback = self.fallback_downloader
            if fallback is None:
                warnings.warn(
                    f"Failed to download model from {self.name} downloader."
                    f" No fallback downloader available."
                )
                raise
            warnings.warn(
                f"Failed to download model from {self.name} downloader."
                f" Trying to download from {fallback.name} downloader."
            )
            return fallback.download_model(file_name)

    @abstractmethod
    def download_model_base(self, file_name: str) -> Path:
        """Download model from any source if not cached. Returns path if cached"""

    def __call__(self, file_name: str):
        """Shortcut for download_model(file_name)."""
        return self.download_model(file_name)
124
+
125
+
126
class HuggingFaceCompatibleDownloader(CachedDownloader, ABC):
    """
    Downloader for hosts that expose the URL layout
    `{base_url}/{repository}/resolve/{revision}/{filename}`.

    Files are cached under `cache_dir/<repository name>/<file_name>`.
    """

    def __init__(
        self,
        name: str = "Huggingface.co",
        base_url: str = "https://huggingface.co",
        fb_downloader: Optional["CachedDownloader"] = None,
    ):
        """
        Args:
            name: downloader name used in warning messages
            base_url: root URL of the host
            fb_downloader: downloader to fall back to on failure, or None
        """
        self.cache_dir = checkpoints_dir
        self.base_url = base_url
        self._name = name
        self._fallback_downloader = fb_downloader

    @property
    def fallback_downloader(self) -> Optional["CachedDownloader"]:
        return self._fallback_downloader

    @property
    def name(self):
        return self._name

    def _model_path(self, file_name: str) -> Path:
        """Returns the cache location for a model listed in MODELS_URLS."""
        repository_name = MODELS_URLS[file_name]["repository"].split("/")[1]
        return self.cache_dir / repository_name / file_name

    @staticmethod
    def _discard_partial(path: Path) -> None:
        """Removes a partially written file so it is never mistaken for a cached model."""
        if path.exists():
            os.remove(path)

    def check_for_existence(self, file_name: str) -> Optional[Path]:
        """
        Looks up the model in the cache and validates its checksum.

        Args:
            file_name: name of the model file

        Returns:
            Path of the cached file, or None when it is missing or was
            removed because of a checksum mismatch.

        Raises:
            FileNotFoundError: if the model is not listed in MODELS_URLS
        """
        if file_name not in MODELS_URLS:
            raise FileNotFoundError("Unknown model!")
        path = self._model_path(file_name)

        if not path.exists():
            return None

        if MODELS_CHECKSUMS[path.name] != sha512_checksum_calc(path):
            warnings.warn(
                f"Invalid checksum for model {path.name}. Downloading correct model!"
            )
            os.remove(path)
            return None
        return path

    def download_model_base(self, file_name: str) -> Path:
        """
        Returns the cached model, downloading it first when necessary.

        Args:
            file_name: name of the model file

        Returns:
            Path of the model file in the cache.

        Raises:
            FileNotFoundError: if the model is not listed in MODELS_URLS
            ConnectionError: if the download fails for any reason; the
                original error is attached as the cause.
        """
        cached_path = self.check_for_existence(file_name)
        if cached_path is not None:
            return cached_path

        cached_path = self._model_path(file_name)
        cached_path.parent.mkdir(parents=True, exist_ok=True)
        url = MODELS_URLS[file_name]
        hugging_face_url = f"{self.base_url}/{url['repository']}/resolve/{url['revision']}/{url['filename']}"

        try:
            r = requests.get(hugging_face_url, stream=True, timeout=10)
            try:
                if r.status_code == 404:
                    raise FileNotFoundError(f"Model {file_name} not found!")
                if r.status_code >= 400:
                    raise ConnectionError(
                        f"Error {r.status_code} while downloading model {file_name}!"
                    )
                with open(cached_path, "wb") as f:
                    r.raw.decode_content = True
                    for chunk in tqdm.tqdm(
                        r,
                        desc="Downloading " + cached_path.name + " model",
                        colour="blue",
                    ):
                        f.write(chunk)
            finally:
                # A streamed response keeps its connection open until closed.
                r.close()
        except Exception as e:
            self._discard_partial(cached_path)
            raise ConnectionError(
                f"Exception caught when downloading model! "
                f"Model name: {cached_path.name}. Exception: {str(e)}."
            ) from e
        except BaseException:
            # Ctrl-C / interpreter exit: clean up, but do not disguise it
            # as a connection problem.
            self._discard_partial(cached_path)
            raise
        return cached_path
206
+
207
+
208
+ fallback_downloader: CachedDownloader = HuggingFaceCompatibleDownloader()
209
+ downloader: CachedDownloader = HuggingFaceCompatibleDownloader(
210
+ base_url="https://cdn.carve.photos",
211
+ fb_downloader=fallback_downloader,
212
+ name="Carve CDN",
213
+ )
214
+ downloader._fallback_downloader = fallback_downloader
carvekit/utils/fs_utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ from pathlib import Path
7
+ from PIL import Image
8
+ import warnings
9
+ from typing import Optional
10
+
11
+
12
def save_file(output: Optional[Path], input_path: Path, image: Image.Image):
    """
    Saves an image to the file system as a .png file.

    Destination rules:
        * output is an existing directory: `<output>/<input name>.png`
        * output has a suffix: `output` with the suffix replaced by .png
        * output is None or "none": `<input stem>_bg_removed.png` next to the input
        * any other kind of output value: nothing is saved

    Args:
        output: Output path [dir or end file]
        input_path: Input path of the image
        image: Image to be saved.

    Raises:
        ValueError: if output is a Path that is neither an existing directory
            nor a path with a suffix
    """
    if isinstance(output, Path) and str(output) != "none":
        if output.is_dir() and output.exists():
            target = output.joinpath(input_path.with_suffix(".png").name)
        elif output.suffix != "":
            if output.suffix != ".png":
                warnings.warn(
                    f"Only export with .png extension is supported! Your {output.suffix}"
                    f" extension will be ignored and replaced with .png!"
                )
            target = output.with_suffix(".png")
        else:
            raise ValueError("Wrong output path!")
    elif output is None or str(output) == "none":
        # Only the part of the stem before the first dot is kept.
        base_name = input_path.stem.split(".")[0] + "_bg_removed"
        target = input_path.with_name(base_name).with_suffix(".png")
    else:
        return
    image.save(target)
carvekit/utils/image_utils.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+
7
+ import pathlib
8
+ from typing import Union, Any, Tuple
9
+
10
+ import PIL.Image
11
+ import numpy as np
12
+ import torch
13
+
14
+ ALLOWED_SUFFIXES = [".jpg", ".jpeg", ".bmp", ".png", ".webp"]
15
+
16
+
17
def to_tensor(x: Any) -> torch.Tensor:
    """
    Converts an array-like object (e.g. a PIL.Image.Image) to a torch tensor,
    keeping the axis order of the numpy representation.

    Args:
        x: PIL.Image.Image instance

    Returns:
        torch.Tensor instance
    """
    pixels = np.array(x, copy=True)
    return torch.tensor(pixels)
28
+
29
+
30
def load_image(file: Union[str, pathlib.Path, PIL.Image.Image]) -> PIL.Image.Image:
    """Opens an image given as string path or pathlib path; PIL images are returned as is.

    Args:
        file: File path or PIL.Image.Image instance

    Returns:
        PIL.Image.Image instance

    Raises:
        ValueError: If file not exists or file is directory or file isn't an image or file is not correct PIL Image

    """
    if isinstance(file, str):
        # Validation raises on a bad path, so a falsy result is not expected here.
        if is_image_valid(pathlib.Path(file)):
            return PIL.Image.open(file)
    elif isinstance(file, PIL.Image.Image):
        return file
    elif isinstance(file, pathlib.Path):
        if is_image_valid(file):
            return PIL.Image.open(str(file))
    raise ValueError("Unknown input file type")
51
+
52
+
53
def convert_image(image: PIL.Image.Image, mode="RGB") -> PIL.Image.Image:
    """Validates the image and converts it to the requested color mode

    Args:
        image: PIL.Image.Image instance
        mode: Color mode to convert to

    Returns:
        PIL.Image.Image instance

    Raises:
        ValueError: If image hasn't convertable color mode, or it is too small
    """
    if not is_image_valid(image):
        return None
    return image.convert(mode)
68
+
69
+
70
def is_image_valid(image: Union[pathlib.Path, PIL.Image.Image]) -> bool:
    """Validates an image path or an in-memory image.

    A path must exist, must not be a directory and must carry one of the
    ALLOWED_SUFFIXES. A PIL image must be larger than 32 pixels on both
    sides and be in RGB, RGBA or L mode.

    Args:
        image: Path to the image or PIL.Image.Image instance being checked.

    Returns:
        True if image is valid

    Raises:
        ValueError: If file not a valid image path or image hasn't convertable color mode, or it is too small

    """
    if isinstance(image, pathlib.Path):
        if not image.exists():
            raise ValueError("File is not exists")
        if image.is_dir():
            raise ValueError("File is a directory")
        if image.suffix.lower() not in ALLOWED_SUFFIXES:
            raise ValueError(
                f"Unsupported image format. Supported file formats: {', '.join(ALLOWED_SUFFIXES)}"
            )
        return True

    if isinstance(image, PIL.Image.Image):
        width, height = image.size[0], image.size[1]
        if width <= 32 or height <= 32:
            raise ValueError("Image should be bigger then (32x32) pixels.")
        if image.mode not in ["RGB", "RGBA", "L"]:
            raise ValueError("Wrong image color mode.")
        return True

    raise ValueError("Unknown input file type")
100
+
101
+
102
def transparency_paste(
    bg_img: PIL.Image.Image, fg_img: PIL.Image.Image, box=(0, 0)
) -> PIL.Image.Image:
    """
    Composites a foreground image onto a background image, keeping transparency.

    The foreground is first pasted (using itself as the paste mask) onto a new
    RGBA layer of the background's size, which is then alpha-composited
    over the background.

    Args:
        bg_img: background image
        fg_img: foreground image
        box: place to paste

    Returns:
        Background image with pasted foreground image at point or in the specified box
    """
    overlay = PIL.Image.new("RGBA", bg_img.size)
    overlay.paste(fg_img, box, mask=fg_img)
    return PIL.Image.alpha_composite(bg_img, overlay)
120
+
121
+
122
def add_margin(
    pil_img: PIL.Image.Image,
    top: int,
    right: int,
    bottom: int,
    left: int,
    color: Tuple[int, int, int, int],
) -> PIL.Image.Image:
    """
    Returns a new, larger image with the original pasted inside a colored margin.

    Args:
        pil_img: Image that needed to add margin.
        top: pixels count at top side
        right: pixels count at right side
        bottom: pixels count at bottom side
        left: pixels count at left side
        color: color of margin

    Returns:
        Image with margin.
    """
    src_width, src_height = pil_img.size
    canvas_size = (src_width + right + left, src_height + top + bottom)
    # noinspection PyTypeChecker
    canvas = PIL.Image.new(pil_img.mode, canvas_size, color)
    canvas.paste(pil_img, (left, top))
    return canvas
carvekit/utils/mask_utils.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ import PIL.Image
7
+ import torch
8
+ from carvekit.utils.image_utils import to_tensor
9
+
10
+
11
def composite(
    foreground: PIL.Image.Image,
    background: PIL.Image.Image,
    alpha: PIL.Image.Image,
    device="cpu",
):
    """
    Blends foreground over background with the given alpha mask, following the
    https://pymatting.github.io/intro.html#alpha-matting math formula.

    Args:
        device: Processing device
        foreground: Image that will be pasted to background image with following alpha mask.
        background: Background image
        alpha: Alpha Image

    Returns:
        Composited image as PIL.Image instance.
    """
    fg = to_tensor(foreground.convert("RGBA")).to(device)
    alpha_rgba = to_tensor(alpha.convert("RGBA")).to(device)
    alpha_l = to_tensor(alpha.convert("L")).to(device)
    bg = to_tensor(background.convert("RGBA")).to(device)

    # Bring both alpha representations into the 0..1 range.
    alpha_l = alpha_l / 255
    alpha_rgba = alpha_rgba / 255

    # Where the alpha is fully opaque take the foreground directly.
    bg = torch.where(alpha_rgba >= 1, fg, bg)
    # Blend the three color channels one by one: a * F + (1 - a) * B.
    for channel in range(3):
        bg[:, :, channel] = (
            alpha_l * fg[:, :, channel] + (1 - alpha_l) * bg[:, :, channel]
        )
    # The alpha channel of the result is the mask itself.
    bg[:, :, 3] = alpha_l * 255

    return PIL.Image.fromarray(bg.cpu().numpy()).convert("RGBA")
52
+
53
+
54
def apply_mask(
    image: PIL.Image.Image, mask: PIL.Image.Image, device="cpu"
) -> PIL.Image.Image:
    """
    Composites the image over a fully transparent gray canvas using the mask as alpha.

    Args:
        device: Processing device.
        image: Image with background.
        mask: Alpha Channel mask for this image.

    Returns:
        Image without background, where mask was black.
    """
    transparent_canvas = PIL.Image.new(
        "RGBA", image.size, color=(130, 130, 130, 0)
    )
    blended = composite(image, transparent_canvas, mask, device=device)
    return blended.convert("RGBA")
70
+
71
+
72
def extract_alpha_channel(image: PIL.Image.Image) -> PIL.Image.Image:
    """
    Renders the last band of the image onto an opaque black RGBA canvas.

    Args:
        image: RGBA PIL image

    Returns:
        RGBA alpha channel image
    """
    alpha_band = image.split()[-1]
    canvas = PIL.Image.new("RGBA", image.size, (0, 0, 0, 255))
    # The band is used both as the pasted content and as the paste mask.
    canvas.paste(alpha_band, mask=alpha_band)
    return canvas.convert("RGBA")
carvekit/utils/models_utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+
7
+ import random
8
+ import warnings
9
+ from typing import Union, Tuple, Any
10
+
11
+ import torch
12
+ from torch import autocast
13
+
14
+
15
class EmptyAutocast(object):
    """
    No-op stand-in for an autocast object, used to disable any autocasting.

    It can be used both as a context manager and as a decorator; in both cases
    it does nothing.
    """

    def __enter__(self):
        # Nothing to set up; the ``with`` target is None.
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None (falsy) lets any exception raised in the block propagate.
        return

    def __call__(self, func):
        """
        Decorator form: hands back the decorated callable unchanged.

        Args:
            func: Callable being decorated.

        Returns:
            The same callable, unmodified. Returning None here would replace
            the decorated function with None.
        """
        return func
28
+
29
+
30
def get_precision_autocast(
    device="cpu", fp16=True, override_dtype=None
) -> Union[
    Tuple[EmptyAutocast, Union[torch.dtype, Any]],
    Tuple[autocast, Union[torch.dtype, Any]],
]:
    """
    Returns precision and autocast settings for given device and fp16 settings.

    Args:
        device: Device to get precision and autocast settings for. Accepts a device
            string ("cpu", "cuda", "cuda:0", ...) or a torch.device instance.
        fp16: Whether to use fp16 precision.
        override_dtype: Override dtype for autocast.

    Returns:
        Autocast object, dtype
    """
    # torch.autocast takes the bare device type ("cuda"), not an indexed device
    # string ("cuda:0"), so the device is normalised before any checks are made.
    device_type = torch.device(device).type

    dtype = torch.float32
    cache_enabled = None

    if device_type == "cpu" and fp16:
        warnings.warn("FP16 is not supported on CPU. Using FP32 instead.")
        # TODO: Implement BFP16 on CPU. There are unexpected slowdowns on cpu on a clean environment.

    if device_type == "cuda" and fp16:
        dtype = torch.float16
        cache_enabled = True

    if override_dtype is not None:
        dtype = override_dtype

    # Full precision on CPU needs no autocasting at all, so a no-op object is returned.
    if dtype == torch.float32 and device_type == "cpu":
        return EmptyAutocast(), dtype

    return (
        torch.autocast(
            device_type=device_type,
            dtype=dtype,
            enabled=True,
            cache_enabled=cache_enabled,
        ),
        dtype,
    )
79
+
80
+
81
def cast_network(network: torch.nn.Module, dtype: torch.dtype):
    """Converts the network to the requested floating point dtype.

    Args:
        network: Network to be cast
        dtype: Dtype to cast network to (float16, bfloat16 or float32)

    Raises:
        ValueError: If dtype is not one of the supported dtypes.
    """
    # Each supported dtype is paired with the name of the module method that performs the cast.
    supported_casts = (
        (torch.float16, "half"),
        (torch.bfloat16, "bfloat16"),
        (torch.float32, "float"),
    )
    for known_dtype, method_name in supported_casts:
        if dtype == known_dtype:
            getattr(network, method_name)()
            return
    raise ValueError(f"Unknown dtype {dtype}")
96
+
97
+
98
def fix_seed(seed=42):
    """Seeds the random number generators and disables cuDNN autotuning.

    Args:
        seed: Random seed to be set

    Returns:
        Always True.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    # noinspection PyUnresolvedReferences
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
    return True
114
+
115
+
116
def suppress_warnings():
    """Installs a filter that hides the nn.MaxPool2d argument-order UserWarning from torch."""
    # Suppress PyTorch 1.11.0 warning associated with changing order of args in nn.MaxPool2d layer,
    # since source code is not affected by this issue and there isn't any other correct way to
    # hide this message. The missing space in "changeto" is deliberate: the filter must match
    # the text exactly as it is written here.
    maxpool_notice = (
        "Note that order of the arguments: ceil_mode and "
        "return_indices will changeto match the args list "
        "in nn.MaxPool2d in a future release."
    )
    warnings.filterwarnings(
        action="ignore",
        message=maxpool_notice,
        category=UserWarning,
        module="torch",
    )
carvekit/utils/pool_utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Source url: https://github.com/OPHoperHPO/image-background-remove-tool
3
+ Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
4
+ License: Apache License 2.0
5
+ """
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ from typing import Any, Iterable
8
+
9
+
10
def thread_pool_processing(func: Any, data: Iterable, workers=18):
    """
    Applies a function to every item of an iterable using a pool of threads.

    Args:
        func: function to pass data through
        data: input iterator
        workers: Count of workers.

    Returns:
        List of the function results, in the same order as the input items.
    """
    with ThreadPoolExecutor(max_workers=workers) as executor:
        # Materialise inside the ``with`` block so every result is collected
        # before the pool is shut down.
        results = list(executor.map(func, data))
    return results
25
+
26
+
27
def batch_generator(iterable, n=1):
    """
    Yields consecutive slices of at most n items from a sized, sliceable sequence.

    Args:
        iterable: sequence supporting len() and slicing
        n: size of packets

    Returns:
        Generator of n-size packets; the final packet may be shorter.
    """
    total = len(iterable)
    yield from (
        # The stop index is clamped so the last packet never reaches past the end.
        iterable[start : min(start + n, total)]
        for start in range(0, total, n)
    )
carvekit/web/__init__.py ADDED
File without changes
carvekit/web/app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import uvicorn
4
+ from fastapi import FastAPI
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from starlette.staticfiles import StaticFiles
7
+
8
+ from carvekit import version
9
+ from carvekit.web.deps import config
10
+ from carvekit.web.routers.api_router import api_router
11
+
12
# Application instance; the reported API version is the ``version`` imported from carvekit.
app = FastAPI(title="CarveKit Web API", version=version)

# CORS is configured with wildcard origins, methods and headers, with credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# API routes are registered under /api before the static files are mounted at the root path.
app.include_router(api_router, prefix="/api")
app.mount(
    "/",
    StaticFiles(directory=Path(__file__).parent / "static", html=True),
    name="static",
)

if __name__ == "__main__":
    # Running this module directly serves the app on the configured host and port.
    uvicorn.run(app, host=config.host, port=config.port)
carvekit/web/deps.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from carvekit.web.schemas.config import WebAPIConfig
2
+ from carvekit.web.utils.init_utils import init_config
3
+ from carvekit.web.utils.task_queue import MLProcessor
4
+
5
# Web API configuration, created at import time by init_config().
config: WebAPIConfig = init_config()
# Module-level MLProcessor instance constructed from that configuration.
ml_processor = MLProcessor(api_config=config)
carvekit/web/handlers/__init__.py ADDED
File without changes