jayparmr committed on
Commit 42ef134 · 1 Parent(s): fd77a63

Upload folder using huggingface_hub

.gitignore CHANGED
@@ -1,4 +1,5 @@
  *.pyc
+ .DS_Store
  .ipynb_checkpoints
  env
  test.py
external/midas/.ipynb_checkpoints/api-checkpoint.py ADDED
@@ -0,0 +1,164 @@
+ # based on https://github.com/isl-org/MiDaS
+
+ import cv2
+ import torch
+ import torch.nn as nn
+ from torchvision.transforms import Compose
+
+ from .midas.dpt_depth import DPTDepthModel
+ from .midas.midas_net import MidasNet
+ from .midas.midas_net_custom import MidasNet_small
+ from .midas.transforms import Resize, NormalizeImage, PrepareForNet
+ from torchvision.datasets.utils import download_url
+ from pathlib import Path
+
+ ISL_PATHS = {
+     "dpt_large": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
+     "dpt_hybrid": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
+     "midas_v21": "",
+     "midas_v21_small": "",
+ }
+
+
+ def disabled_train(self, mode=True):
+     """Overwrite model.train with this function to make sure train/eval mode
+     does not change anymore."""
+     return self
+
+
+ def load_midas_transform(model_type):
+     # https://github.com/isl-org/MiDaS/blob/master/run.py
+     # load transform only
+     if model_type == "dpt_large":  # DPT-Large
+         net_w, net_h = 384, 384
+         resize_mode = "minimal"
+         normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+     elif model_type == "dpt_hybrid":  # DPT-Hybrid
+         net_w, net_h = 384, 384
+         resize_mode = "minimal"
+         normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+     elif model_type == "midas_v21":
+         net_w, net_h = 384, 384
+         resize_mode = "upper_bound"
+         normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+     elif model_type == "midas_v21_small":
+         net_w, net_h = 256, 256
+         resize_mode = "upper_bound"
+         normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+     else:
+         assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+     transform = Compose(
+         [
+             Resize(
+                 net_w,
+                 net_h,
+                 resize_target=None,
+                 keep_aspect_ratio=True,
+                 ensure_multiple_of=32,
+                 resize_method=resize_mode,
+                 image_interpolation_method=cv2.INTER_CUBIC,
+             ),
+             normalization,
+             PrepareForNet(),
+         ]
+     )
+
+     return transform
+
+
+ def load_model(model_type):
+     # https://github.com/isl-org/MiDaS/blob/master/run.py
+     # load network
+     model_path = ISL_PATHS[model_type]
+     download_url(model_path, "Intel-isl")
+     model_path = f"Intel-isl/{model_path.split('/')[-1]}"
+     if model_type == "dpt_large":  # DPT-Large
+         model = DPTDepthModel(
+             path=model_path,
+             backbone="vitl16_384",
+             non_negative=True,
+         )
+         net_w, net_h = 384, 384
+         resize_mode = "minimal"
+         normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+     elif model_type == "dpt_hybrid":  # DPT-Hybrid
+         model = DPTDepthModel(
+             path=model_path,
+             backbone="vitb_rn50_384",
+             non_negative=True,
+         )
+         net_w, net_h = 384, 384
+         resize_mode = "minimal"
+         normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+     elif model_type == "midas_v21":
+         model = MidasNet(model_path, non_negative=True)
+         net_w, net_h = 384, 384
+         resize_mode = "upper_bound"
+         normalization = NormalizeImage(
+             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+         )
+
+     elif model_type == "midas_v21_small":
+         model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+                                non_negative=True, blocks={'expand': True})
+         net_w, net_h = 256, 256
+         resize_mode = "upper_bound"
+         normalization = NormalizeImage(
+             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+         )
+
+     else:
+         print(f"model_type '{model_type}' not implemented, use: --model_type large")
+         assert False
+
+     transform = Compose(
+         [
+             Resize(
+                 net_w,
+                 net_h,
+                 resize_target=None,
+                 keep_aspect_ratio=True,
+                 ensure_multiple_of=32,
+                 resize_method=resize_mode,
+                 image_interpolation_method=cv2.INTER_CUBIC,
+             ),
+             normalization,
+             PrepareForNet(),
+         ]
+     )
+
+     return model.eval(), transform
+
+
+ class MiDaSInference(nn.Module):
+     MODEL_TYPES_TORCH_HUB = [
+         "DPT_Large",
+         "DPT_Hybrid",
+         "MiDaS_small"
+     ]
+     MODEL_TYPES_ISL = [
+         "dpt_large",
+         "dpt_hybrid",
+         "midas_v21",
+         "midas_v21_small",
+     ]
+
+     def __init__(self, model_type):
+         super().__init__()
+         assert (model_type in self.MODEL_TYPES_ISL)
+         model, _ = load_model(model_type)
+         self.model = model
+         self.model.train = disabled_train
+
+     def forward(self, x):
+         with torch.no_grad():
+             prediction = self.model(x)
+         return prediction
+
external/midas/__init__.py ADDED
@@ -0,0 +1,36 @@
+ import cv2
+ import numpy as np
+ import torch
+
+ from einops import rearrange
+ from .api import MiDaSInference
+
+ model = MiDaSInference(model_type="dpt_hybrid").cuda()
+
+
+ def apply_midas(input_image, a=np.pi * 2.0, bg_th=0.1):
+     assert input_image.ndim == 3
+     image_depth = input_image
+     with torch.no_grad():
+         image_depth = torch.from_numpy(image_depth).float().cuda()
+         image_depth = image_depth / 127.5 - 1.0
+         image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
+         depth = model(image_depth)[0]
+
+         depth_pt = depth.clone()
+         depth_pt -= torch.min(depth_pt)
+         depth_pt /= torch.max(depth_pt)
+         depth_pt = depth_pt.cpu().numpy()
+         depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
+
+         depth_np = depth.cpu().numpy()
+         x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
+         y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
+         z = np.ones_like(x) * a
+         x[depth_pt < bg_th] = 0
+         y[depth_pt < bg_th] = 0
+         normal = np.stack([x, y, z], axis=2)
+         normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
+         normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
+
+         return depth_image, normal_image
external/midas/api.py ADDED
@@ -0,0 +1,172 @@
+ # based on https://github.com/isl-org/MiDaS
+
+ from pathlib import Path
+
+ import cv2
+ import torch
+ import torch.nn as nn
+ from torchvision.datasets.utils import download_url
+ from torchvision.transforms import Compose
+
+ from .midas.dpt_depth import DPTDepthModel
+ from .midas.midas_net import MidasNet
+ from .midas.midas_net_custom import MidasNet_small
+ from .midas.transforms import NormalizeImage, PrepareForNet, Resize
+
+ ISL_PATHS = {
+     "dpt_large": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
+     "dpt_hybrid": "https://github.com/isl-org/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
+     "midas_v21": "",
+     "midas_v21_small": "",
+ }
+
+
+ def disabled_train(self, mode=True):
+     """Overwrite model.train with this function to make sure train/eval mode
+     does not change anymore."""
+     return self
+
+
+ def load_midas_transform(model_type):
+     # https://github.com/isl-org/MiDaS/blob/master/run.py
+     # load transform only
+     if model_type == "dpt_large":  # DPT-Large
+         net_w, net_h = 384, 384
+         resize_mode = "minimal"
+         normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+     elif model_type == "dpt_hybrid":  # DPT-Hybrid
+         net_w, net_h = 384, 384
+         resize_mode = "minimal"
+         normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+     elif model_type == "midas_v21":
+         net_w, net_h = 384, 384
+         resize_mode = "upper_bound"
+         normalization = NormalizeImage(
+             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+         )
+
+     elif model_type == "midas_v21_small":
+         net_w, net_h = 256, 256
+         resize_mode = "upper_bound"
+         normalization = NormalizeImage(
+             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+         )
+
+     else:
+         assert (
+             False
+         ), f"model_type '{model_type}' not implemented, use: --model_type large"
+
+     transform = Compose(
+         [
+             Resize(
+                 net_w,
+                 net_h,
+                 resize_target=None,
+                 keep_aspect_ratio=True,
+                 ensure_multiple_of=32,
+                 resize_method=resize_mode,
+                 image_interpolation_method=cv2.INTER_CUBIC,
+             ),
+             normalization,
+             PrepareForNet(),
+         ]
+     )
+
+     return transform
+
+
+ def load_model(model_type):
+     # https://github.com/isl-org/MiDaS/blob/master/run.py
+     # load network
+     model_path = ISL_PATHS[model_type]
+     download_url(model_path, "~/.cache/Intel-isl")
+     model_path = f"{Path.home()}/.cache/Intel-isl/{model_path.split('/')[-1]}"
+     if model_type == "dpt_large":  # DPT-Large
+         model = DPTDepthModel(
+             path=model_path,
+             backbone="vitl16_384",
+             non_negative=True,
+         )
+         net_w, net_h = 384, 384
+         resize_mode = "minimal"
+         normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+     elif model_type == "dpt_hybrid":  # DPT-Hybrid
+         model = DPTDepthModel(
+             path=model_path,
+             backbone="vitb_rn50_384",
+             non_negative=True,
+         )
+         net_w, net_h = 384, 384
+         resize_mode = "minimal"
+         normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+     elif model_type == "midas_v21":
+         model = MidasNet(model_path, non_negative=True)
+         net_w, net_h = 384, 384
+         resize_mode = "upper_bound"
+         normalization = NormalizeImage(
+             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+         )
+
+     elif model_type == "midas_v21_small":
+         model = MidasNet_small(
+             model_path,
+             features=64,
+             backbone="efficientnet_lite3",
+             exportable=True,
+             non_negative=True,
+             blocks={"expand": True},
+         )
+         net_w, net_h = 256, 256
+         resize_mode = "upper_bound"
+         normalization = NormalizeImage(
+             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+         )
+
+     else:
+         print(f"model_type '{model_type}' not implemented, use: --model_type large")
+         assert False
+
+     transform = Compose(
+         [
+             Resize(
+                 net_w,
+                 net_h,
+                 resize_target=None,
+                 keep_aspect_ratio=True,
+                 ensure_multiple_of=32,
+                 resize_method=resize_mode,
+                 image_interpolation_method=cv2.INTER_CUBIC,
+             ),
+             normalization,
+             PrepareForNet(),
+         ]
+     )
+
+     return model.eval(), transform
+
+
+ class MiDaSInference(nn.Module):
+     MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"]
+     MODEL_TYPES_ISL = [
+         "dpt_large",
+         "dpt_hybrid",
+         "midas_v21",
+         "midas_v21_small",
+     ]
+
+     def __init__(self, model_type):
+         super().__init__()
+         assert model_type in self.MODEL_TYPES_ISL
+         model, _ = load_model(model_type)
+         self.model = model
+         self.model.train = disabled_train
+
+     def forward(self, x):
+         with torch.no_grad():
+             prediction = self.model(x)
+         return prediction
external/midas/midas/__init__.py ADDED
File without changes
external/midas/midas/base_model.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+
+
+ class BaseModel(torch.nn.Module):
+     def load(self, path):
+         """Load model from file.
+
+         Args:
+             path (str): file path
+         """
+         parameters = torch.load(path, map_location=torch.device('cpu'))
+
+         if "optimizer" in parameters:
+             parameters = parameters["model"]
+
+         self.load_state_dict(parameters)
external/midas/midas/blocks.py ADDED
@@ -0,0 +1,342 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12
+ if backbone == "vitl16_384":
13
+ pretrained = _make_pretrained_vitl16_384(
14
+ use_pretrained, hooks=hooks, use_readout=use_readout
15
+ )
16
+ scratch = _make_scratch(
17
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
18
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
19
+ elif backbone == "vitb_rn50_384":
20
+ pretrained = _make_pretrained_vitb_rn50_384(
21
+ use_pretrained,
22
+ hooks=hooks,
23
+ use_vit_only=use_vit_only,
24
+ use_readout=use_readout,
25
+ )
26
+ scratch = _make_scratch(
27
+ [256, 512, 768, 768], features, groups=groups, expand=expand
28
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
29
+ elif backbone == "vitb16_384":
30
+ pretrained = _make_pretrained_vitb16_384(
31
+ use_pretrained, hooks=hooks, use_readout=use_readout
32
+ )
33
+ scratch = _make_scratch(
34
+ [96, 192, 384, 768], features, groups=groups, expand=expand
35
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
36
+ elif backbone == "resnext101_wsl":
37
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
39
+ elif backbone == "efficientnet_lite3":
40
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42
+ else:
43
+ print(f"Backbone '{backbone}' not implemented")
44
+ assert False
45
+
46
+ return pretrained, scratch
47
+
48
+
49
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50
+ scratch = nn.Module()
51
+
52
+ out_shape1 = out_shape
53
+ out_shape2 = out_shape
54
+ out_shape3 = out_shape
55
+ out_shape4 = out_shape
56
+ if expand==True:
57
+ out_shape1 = out_shape
58
+ out_shape2 = out_shape*2
59
+ out_shape3 = out_shape*4
60
+ out_shape4 = out_shape*8
61
+
62
+ scratch.layer1_rn = nn.Conv2d(
63
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64
+ )
65
+ scratch.layer2_rn = nn.Conv2d(
66
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67
+ )
68
+ scratch.layer3_rn = nn.Conv2d(
69
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70
+ )
71
+ scratch.layer4_rn = nn.Conv2d(
72
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73
+ )
74
+
75
+ return scratch
76
+
77
+
78
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79
+ efficientnet = torch.hub.load(
80
+ "rwightman/gen-efficientnet-pytorch",
81
+ "tf_efficientnet_lite3",
82
+ pretrained=use_pretrained,
83
+ exportable=exportable
84
+ )
85
+ return _make_efficientnet_backbone(efficientnet)
86
+
87
+
88
+ def _make_efficientnet_backbone(effnet):
89
+ pretrained = nn.Module()
90
+
91
+ pretrained.layer1 = nn.Sequential(
92
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93
+ )
94
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97
+
98
+ return pretrained
99
+
100
+
101
+ def _make_resnet_backbone(resnet):
102
+ pretrained = nn.Module()
103
+ pretrained.layer1 = nn.Sequential(
104
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105
+ )
106
+
107
+ pretrained.layer2 = resnet.layer2
108
+ pretrained.layer3 = resnet.layer3
109
+ pretrained.layer4 = resnet.layer4
110
+
111
+ return pretrained
112
+
113
+
114
+ def _make_pretrained_resnext101_wsl(use_pretrained):
115
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116
+ return _make_resnet_backbone(resnet)
117
+
118
+
119
+
120
+ class Interpolate(nn.Module):
121
+ """Interpolation module.
122
+ """
123
+
124
+ def __init__(self, scale_factor, mode, align_corners=False):
125
+ """Init.
126
+
127
+ Args:
128
+ scale_factor (float): scaling
129
+ mode (str): interpolation mode
130
+ """
131
+ super(Interpolate, self).__init__()
132
+
133
+ self.interp = nn.functional.interpolate
134
+ self.scale_factor = scale_factor
135
+ self.mode = mode
136
+ self.align_corners = align_corners
137
+
138
+ def forward(self, x):
139
+ """Forward pass.
140
+
141
+ Args:
142
+ x (tensor): input
143
+
144
+ Returns:
145
+ tensor: interpolated data
146
+ """
147
+
148
+ x = self.interp(
149
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150
+ )
151
+
152
+ return x
153
+
154
+
155
+ class ResidualConvUnit(nn.Module):
156
+ """Residual convolution module.
157
+ """
158
+
159
+ def __init__(self, features):
160
+ """Init.
161
+
162
+ Args:
163
+ features (int): number of features
164
+ """
165
+ super().__init__()
166
+
167
+ self.conv1 = nn.Conv2d(
168
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
169
+ )
170
+
171
+ self.conv2 = nn.Conv2d(
172
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
173
+ )
174
+
175
+ self.relu = nn.ReLU(inplace=True)
176
+
177
+ def forward(self, x):
178
+ """Forward pass.
179
+
180
+ Args:
181
+ x (tensor): input
182
+
183
+ Returns:
184
+ tensor: output
185
+ """
186
+ out = self.relu(x)
187
+ out = self.conv1(out)
188
+ out = self.relu(out)
189
+ out = self.conv2(out)
190
+
191
+ return out + x
192
+
193
+
194
+ class FeatureFusionBlock(nn.Module):
195
+ """Feature fusion block.
196
+ """
197
+
198
+ def __init__(self, features):
199
+ """Init.
200
+
201
+ Args:
202
+ features (int): number of features
203
+ """
204
+ super(FeatureFusionBlock, self).__init__()
205
+
206
+ self.resConfUnit1 = ResidualConvUnit(features)
207
+ self.resConfUnit2 = ResidualConvUnit(features)
208
+
209
+ def forward(self, *xs):
210
+ """Forward pass.
211
+
212
+ Returns:
213
+ tensor: output
214
+ """
215
+ output = xs[0]
216
+
217
+ if len(xs) == 2:
218
+ output += self.resConfUnit1(xs[1])
219
+
220
+ output = self.resConfUnit2(output)
221
+
222
+ output = nn.functional.interpolate(
223
+ output, scale_factor=2, mode="bilinear", align_corners=True
224
+ )
225
+
226
+ return output
227
+
228
+
229
+
230
+
231
+ class ResidualConvUnit_custom(nn.Module):
232
+ """Residual convolution module.
233
+ """
234
+
235
+ def __init__(self, features, activation, bn):
236
+ """Init.
237
+
238
+ Args:
239
+ features (int): number of features
240
+ """
241
+ super().__init__()
242
+
243
+ self.bn = bn
244
+
245
+ self.groups=1
246
+
247
+ self.conv1 = nn.Conv2d(
248
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249
+ )
250
+
251
+ self.conv2 = nn.Conv2d(
252
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253
+ )
254
+
255
+ if self.bn==True:
256
+ self.bn1 = nn.BatchNorm2d(features)
257
+ self.bn2 = nn.BatchNorm2d(features)
258
+
259
+ self.activation = activation
260
+
261
+ self.skip_add = nn.quantized.FloatFunctional()
262
+
263
+ def forward(self, x):
264
+ """Forward pass.
265
+
266
+ Args:
267
+ x (tensor): input
268
+
269
+ Returns:
270
+ tensor: output
271
+ """
272
+
273
+ out = self.activation(x)
274
+ out = self.conv1(out)
275
+ if self.bn==True:
276
+ out = self.bn1(out)
277
+
278
+ out = self.activation(out)
279
+ out = self.conv2(out)
280
+ if self.bn==True:
281
+ out = self.bn2(out)
282
+
283
+ if self.groups > 1:
284
+ out = self.conv_merge(out)
285
+
286
+ return self.skip_add.add(out, x)
287
+
288
+ # return out + x
289
+
290
+
291
+ class FeatureFusionBlock_custom(nn.Module):
292
+ """Feature fusion block.
293
+ """
294
+
295
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296
+ """Init.
297
+
298
+ Args:
299
+ features (int): number of features
300
+ """
301
+ super(FeatureFusionBlock_custom, self).__init__()
302
+
303
+ self.deconv = deconv
304
+ self.align_corners = align_corners
305
+
306
+ self.groups=1
307
+
308
+ self.expand = expand
309
+ out_features = features
310
+ if self.expand==True:
311
+ out_features = features//2
312
+
313
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314
+
315
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317
+
318
+ self.skip_add = nn.quantized.FloatFunctional()
319
+
320
+ def forward(self, *xs):
321
+ """Forward pass.
322
+
323
+ Returns:
324
+ tensor: output
325
+ """
326
+ output = xs[0]
327
+
328
+ if len(xs) == 2:
329
+ res = self.resConfUnit1(xs[1])
330
+ output = self.skip_add.add(output, res)
331
+ # output += res
332
+
333
+ output = self.resConfUnit2(output)
334
+
335
+ output = nn.functional.interpolate(
336
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337
+ )
338
+
339
+ output = self.out_conv(output)
340
+
341
+ return output
342
+
external/midas/midas/dpt_depth.py ADDED
@@ -0,0 +1,109 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .base_model import BaseModel
+ from .blocks import (
+     FeatureFusionBlock,
+     FeatureFusionBlock_custom,
+     Interpolate,
+     _make_encoder,
+     forward_vit,
+ )
+
+
+ def _make_fusion_block(features, use_bn):
+     return FeatureFusionBlock_custom(
+         features,
+         nn.ReLU(False),
+         deconv=False,
+         bn=use_bn,
+         expand=False,
+         align_corners=True,
+     )
+
+
+ class DPT(BaseModel):
+     def __init__(
+         self,
+         head,
+         features=256,
+         backbone="vitb_rn50_384",
+         readout="project",
+         channels_last=False,
+         use_bn=False,
+     ):
+
+         super(DPT, self).__init__()
+
+         self.channels_last = channels_last
+
+         hooks = {
+             "vitb_rn50_384": [0, 1, 8, 11],
+             "vitb16_384": [2, 5, 8, 11],
+             "vitl16_384": [5, 11, 17, 23],
+         }
+
+         # Instantiate backbone and reassemble blocks
+         self.pretrained, self.scratch = _make_encoder(
+             backbone,
+             features,
+             False,  # Set to true of you want to train from scratch, uses ImageNet weights
+             groups=1,
+             expand=False,
+             exportable=False,
+             hooks=hooks[backbone],
+             use_readout=readout,
+         )
+
+         self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+         self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+         self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+         self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+         self.scratch.output_conv = head
+
+
+     def forward(self, x):
+         if self.channels_last == True:
+             x.contiguous(memory_format=torch.channels_last)
+
+         layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+         layer_1_rn = self.scratch.layer1_rn(layer_1)
+         layer_2_rn = self.scratch.layer2_rn(layer_2)
+         layer_3_rn = self.scratch.layer3_rn(layer_3)
+         layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+         path_4 = self.scratch.refinenet4(layer_4_rn)
+         path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+         path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+         path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+         out = self.scratch.output_conv(path_1)
+
+         return out
+
+
+ class DPTDepthModel(DPT):
+     def __init__(self, path=None, non_negative=True, **kwargs):
+         features = kwargs["features"] if "features" in kwargs else 256
+
+         head = nn.Sequential(
+             nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+             Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+             nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(True),
+             nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+             nn.ReLU(True) if non_negative else nn.Identity(),
+             nn.Identity(),
+         )
+
+         super().__init__(head, **kwargs)
+
+         if path is not None:
+             self.load(path)
+
+     def forward(self, x):
+         return super().forward(x).squeeze(dim=1)
+
external/midas/midas/midas_net.py ADDED
@@ -0,0 +1,76 @@
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+ This file contains code that is adapted from
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+ """
+ import torch
+ import torch.nn as nn
+
+ from .base_model import BaseModel
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+ class MidasNet(BaseModel):
+     """Network for monocular depth estimation.
+     """
+
+     def __init__(self, path=None, features=256, non_negative=True):
+         """Init.
+
+         Args:
+             path (str, optional): Path to saved model. Defaults to None.
+             features (int, optional): Number of features. Defaults to 256.
+             backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+         """
+         print("Loading weights: ", path)
+
+         super(MidasNet, self).__init__()
+
+         use_pretrained = False if path is None else True
+
+         self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+         self.scratch.refinenet4 = FeatureFusionBlock(features)
+         self.scratch.refinenet3 = FeatureFusionBlock(features)
+         self.scratch.refinenet2 = FeatureFusionBlock(features)
+         self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+         self.scratch.output_conv = nn.Sequential(
+             nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+             Interpolate(scale_factor=2, mode="bilinear"),
+             nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(True),
+             nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+             nn.ReLU(True) if non_negative else nn.Identity(),
+         )
+
+         if path:
+             self.load(path)
+
+     def forward(self, x):
+         """Forward pass.
+
+         Args:
+             x (tensor): input data (image)
+
+         Returns:
+             tensor: depth
+         """
+
+         layer_1 = self.pretrained.layer1(x)
+         layer_2 = self.pretrained.layer2(layer_1)
+         layer_3 = self.pretrained.layer3(layer_2)
+         layer_4 = self.pretrained.layer4(layer_3)
+
+         layer_1_rn = self.scratch.layer1_rn(layer_1)
+         layer_2_rn = self.scratch.layer2_rn(layer_2)
+         layer_3_rn = self.scratch.layer3_rn(layer_3)
+         layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+         path_4 = self.scratch.refinenet4(layer_4_rn)
+         path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+         path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+         path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+         out = self.scratch.output_conv(path_1)
+
+         return torch.squeeze(out, dim=1)
external/midas/midas/midas_net_custom.py ADDED
@@ -0,0 +1,128 @@
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+ This file contains code that is adapted from
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+ """
+ import torch
+ import torch.nn as nn
+
+ from .base_model import BaseModel
+ from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+ class MidasNet_small(BaseModel):
+     """Network for monocular depth estimation.
+     """
+
+     def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+                  blocks={'expand': True}):
+         """Init.
+
+         Args:
+             path (str, optional): Path to saved model. Defaults to None.
+             features (int, optional): Number of features. Defaults to 256.
+             backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+         """
+         print("Loading weights: ", path)
+
+         super(MidasNet_small, self).__init__()
+
+         use_pretrained = False if path else True
+
+         self.channels_last = channels_last
+         self.blocks = blocks
+         self.backbone = backbone
+
+         self.groups = 1
+
+         features1 = features
+         features2 = features
+         features3 = features
+         features4 = features
+         self.expand = False
+         if "expand" in self.blocks and self.blocks['expand'] == True:
+             self.expand = True
+             features1 = features
+             features2 = features * 2
+             features3 = features * 4
+             features4 = features * 8
+
+         self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+
+         self.scratch.activation = nn.ReLU(False)
+
+         self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+         self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+         self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+         self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+
+         self.scratch.output_conv = nn.Sequential(
+             nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+             Interpolate(scale_factor=2, mode="bilinear"),
+             nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+             self.scratch.activation,
+             nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+             nn.ReLU(True) if non_negative else nn.Identity(),
+             nn.Identity(),
+         )
+
+         if path:
+             self.load(path)
+
+
+     def forward(self, x):
+         """Forward pass.
+
+         Args:
+             x (tensor): input data (image)
+
+         Returns:
+             tensor: depth
+         """
+         if self.channels_last == True:
+             print("self.channels_last = ", self.channels_last)
+             x.contiguous(memory_format=torch.channels_last)
+
+
+         layer_1 = self.pretrained.layer1(x)
+         layer_2 = self.pretrained.layer2(layer_1)
+         layer_3 = self.pretrained.layer3(layer_2)
+         layer_4 = self.pretrained.layer4(layer_3)
+
+         layer_1_rn = self.scratch.layer1_rn(layer_1)
+         layer_2_rn = self.scratch.layer2_rn(layer_2)
+         layer_3_rn = self.scratch.layer3_rn(layer_3)
+         layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+         path_4 = self.scratch.refinenet4(layer_4_rn)
+         path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+         path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+         path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+         out = self.scratch.output_conv(path_1)
+
+         return torch.squeeze(out, dim=1)
+
+
+
+ def fuse_model(m):
+     prev_previous_type = nn.Identity()
+     prev_previous_name = ''
+     previous_type = nn.Identity()
+     previous_name = ''
+     for name, module in m.named_modules():
+         if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+             # print("FUSED ", prev_previous_name, previous_name, name)
+             torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+         elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+             # print("FUSED ", prev_previous_name, previous_name)
+             torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+         # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+         #     print("FUSED ", previous_name, name)
+         #     torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+         prev_previous_type = previous_type
+         prev_previous_name = previous_name
+         previous_type = type(module)
+         previous_name = name
external/midas/midas/transforms.py ADDED
@@ -0,0 +1,234 @@
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ width,
55
+ height,
56
+ resize_target=True,
57
+ keep_aspect_ratio=False,
58
+ ensure_multiple_of=1,
59
+ resize_method="lower_bound",
60
+ image_interpolation_method=cv2.INTER_AREA,
61
+ ):
62
+ """Init.
63
+
64
+ Args:
65
+ width (int): desired output width
66
+ height (int): desired output height
67
+ resize_target (bool, optional):
68
+ True: Resize the full sample (image, mask, target).
69
+ False: Resize image only.
70
+ Defaults to True.
71
+ keep_aspect_ratio (bool, optional):
72
+ True: Keep the aspect ratio of the input sample.
73
+ Output sample might not have the given width and height, and
74
+ resize behaviour depends on the parameter 'resize_method'.
75
+ Defaults to False.
76
+ ensure_multiple_of (int, optional):
77
+ Output width and height is constrained to be multiple of this parameter.
78
+ Defaults to 1.
79
+ resize_method (str, optional):
80
+ "lower_bound": Output will be at least as large as the given size.
81
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
83
+ Defaults to "lower_bound".
84
+ """
85
+ self.__width = width
86
+ self.__height = height
87
+
88
+ self.__resize_target = resize_target
89
+ self.__keep_aspect_ratio = keep_aspect_ratio
90
+ self.__multiple_of = ensure_multiple_of
91
+ self.__resize_method = resize_method
92
+ self.__image_interpolation_method = image_interpolation_method
93
+
94
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96
+
97
+ if max_val is not None and y > max_val:
98
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if y < min_val:
101
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ return y
104
+
105
+ def get_size(self, width, height):
106
+ # determine new height and width
107
+ scale_height = self.__height / height
108
+ scale_width = self.__width / width
109
+
110
+ if self.__keep_aspect_ratio:
111
+ if self.__resize_method == "lower_bound":
112
+ # scale such that output size is lower bound
113
+ if scale_width > scale_height:
114
+ # fit width
115
+ scale_height = scale_width
116
+ else:
117
+ # fit height
118
+ scale_width = scale_height
119
+ elif self.__resize_method == "upper_bound":
120
+ # scale such that output size is upper bound
121
+ if scale_width < scale_height:
122
+ # fit width
123
+ scale_height = scale_width
124
+ else:
125
+ # fit height
126
+ scale_width = scale_height
127
+ elif self.__resize_method == "minimal":
128
+ # scale as least as possbile
129
+ if abs(1 - scale_width) < abs(1 - scale_height):
130
+ # fit width
131
+ scale_height = scale_width
132
+ else:
133
+ # fit height
134
+ scale_width = scale_height
135
+ else:
136
+ raise ValueError(
137
+ f"resize_method {self.__resize_method} not implemented"
138
+ )
139
+
140
+ if self.__resize_method == "lower_bound":
141
+ new_height = self.constrain_to_multiple_of(
142
+ scale_height * height, min_val=self.__height
143
+ )
144
+ new_width = self.constrain_to_multiple_of(
145
+ scale_width * width, min_val=self.__width
146
+ )
147
+ elif self.__resize_method == "upper_bound":
148
+ new_height = self.constrain_to_multiple_of(
149
+ scale_height * height, max_val=self.__height
150
+ )
151
+ new_width = self.constrain_to_multiple_of(
152
+ scale_width * width, max_val=self.__width
153
+ )
154
+ elif self.__resize_method == "minimal":
155
+ new_height = self.constrain_to_multiple_of(scale_height * height)
156
+ new_width = self.constrain_to_multiple_of(scale_width * width)
157
+ else:
158
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
159
+
160
+ return (new_width, new_height)
161
+
162
+ def __call__(self, sample):
163
+ width, height = self.get_size(
164
+ sample["image"].shape[1], sample["image"].shape[0]
165
+ )
166
+
167
+ # resize sample
168
+ sample["image"] = cv2.resize(
169
+ sample["image"],
170
+ (width, height),
171
+ interpolation=self.__image_interpolation_method,
172
+ )
173
+
174
+ if self.__resize_target:
175
+ if "disparity" in sample:
176
+ sample["disparity"] = cv2.resize(
177
+ sample["disparity"],
178
+ (width, height),
179
+ interpolation=cv2.INTER_NEAREST,
180
+ )
181
+
182
+ if "depth" in sample:
183
+ sample["depth"] = cv2.resize(
184
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185
+ )
186
+
187
+ sample["mask"] = cv2.resize(
188
+ sample["mask"].astype(np.float32),
189
+ (width, height),
190
+ interpolation=cv2.INTER_NEAREST,
191
+ )
192
+ sample["mask"] = sample["mask"].astype(bool)
193
+
194
+ return sample
195
+
196
+
197
+ class NormalizeImage(object):
198
+ """Normlize image by given mean and std.
199
+ """
200
+
201
+ def __init__(self, mean, std):
202
+ self.__mean = mean
203
+ self.__std = std
204
+
205
+ def __call__(self, sample):
206
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
207
+
208
+ return sample
209
+
210
+
211
+ class PrepareForNet(object):
212
+ """Prepare sample for usage as network input.
213
+ """
214
+
215
+ def __init__(self):
216
+ pass
217
+
218
+ def __call__(self, sample):
219
+ image = np.transpose(sample["image"], (2, 0, 1))
220
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221
+
222
+ if "mask" in sample:
223
+ sample["mask"] = sample["mask"].astype(np.float32)
224
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
225
+
226
+ if "disparity" in sample:
227
+ disparity = sample["disparity"].astype(np.float32)
228
+ sample["disparity"] = np.ascontiguousarray(disparity)
229
+
230
+ if "depth" in sample:
231
+ depth = sample["depth"].astype(np.float32)
232
+ sample["depth"] = np.ascontiguousarray(depth)
233
+
234
+ return sample
external/midas/midas/vit.py ADDED
@@ -0,0 +1,491 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class Slice(nn.Module):
10
+ def __init__(self, start_index=1):
11
+ super(Slice, self).__init__()
12
+ self.start_index = start_index
13
+
14
+ def forward(self, x):
15
+ return x[:, self.start_index :]
16
+
17
+
18
+ class AddReadout(nn.Module):
19
+ def __init__(self, start_index=1):
20
+ super(AddReadout, self).__init__()
21
+ self.start_index = start_index
22
+
23
+ def forward(self, x):
24
+ if self.start_index == 2:
25
+ readout = (x[:, 0] + x[:, 1]) / 2
26
+ else:
27
+ readout = x[:, 0]
28
+ return x[:, self.start_index :] + readout.unsqueeze(1)
29
+
30
+
31
+ class ProjectReadout(nn.Module):
32
+ def __init__(self, in_features, start_index=1):
33
+ super(ProjectReadout, self).__init__()
34
+ self.start_index = start_index
35
+
36
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
37
+
38
+ def forward(self, x):
39
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40
+ features = torch.cat((x[:, self.start_index :], readout), -1)
41
+
42
+ return self.project(features)
43
+
44
+
45
+ class Transpose(nn.Module):
46
+ def __init__(self, dim0, dim1):
47
+ super(Transpose, self).__init__()
48
+ self.dim0 = dim0
49
+ self.dim1 = dim1
50
+
51
+ def forward(self, x):
52
+ x = x.transpose(self.dim0, self.dim1)
53
+ return x
54
+
55
+
56
+ def forward_vit(pretrained, x):
57
+ b, c, h, w = x.shape
58
+
59
+ glob = pretrained.model.forward_flex(x)
60
+
61
+ layer_1 = pretrained.activations["1"]
62
+ layer_2 = pretrained.activations["2"]
63
+ layer_3 = pretrained.activations["3"]
64
+ layer_4 = pretrained.activations["4"]
65
+
66
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
68
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
69
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
70
+
71
+ unflatten = nn.Sequential(
72
+ nn.Unflatten(
73
+ 2,
74
+ torch.Size(
75
+ [
76
+ h // pretrained.model.patch_size[1],
77
+ w // pretrained.model.patch_size[0],
78
+ ]
79
+ ),
80
+ )
81
+ )
82
+
83
+ if layer_1.ndim == 3:
84
+ layer_1 = unflatten(layer_1)
85
+ if layer_2.ndim == 3:
86
+ layer_2 = unflatten(layer_2)
87
+ if layer_3.ndim == 3:
88
+ layer_3 = unflatten(layer_3)
89
+ if layer_4.ndim == 3:
90
+ layer_4 = unflatten(layer_4)
91
+
92
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
96
+
97
+ return layer_1, layer_2, layer_3, layer_4
98
+
99
+
100
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
101
+ posemb_tok, posemb_grid = (
102
+ posemb[:, : self.start_index],
103
+ posemb[0, self.start_index :],
104
+ )
105
+
106
+ gs_old = int(math.sqrt(len(posemb_grid)))
107
+
108
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
110
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111
+
112
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
113
+
114
+ return posemb
115
+
116
+
117
+ def forward_flex(self, x):
118
+ b, c, h, w = x.shape
119
+
120
+ pos_embed = self._resize_pos_embed(
121
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122
+ )
123
+
124
+ B = x.shape[0]
125
+
126
+ if hasattr(self.patch_embed, "backbone"):
127
+ x = self.patch_embed.backbone(x)
128
+ if isinstance(x, (list, tuple)):
129
+ x = x[-1] # last feature if backbone outputs list/tuple of features
130
+
131
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132
+
133
+ if getattr(self, "dist_token", None) is not None:
134
+ cls_tokens = self.cls_token.expand(
135
+ B, -1, -1
136
+ ) # stole cls_tokens impl from Phil Wang, thanks
137
+ dist_token = self.dist_token.expand(B, -1, -1)
138
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
139
+ else:
140
+ cls_tokens = self.cls_token.expand(
141
+ B, -1, -1
142
+ ) # stole cls_tokens impl from Phil Wang, thanks
143
+ x = torch.cat((cls_tokens, x), dim=1)
144
+
145
+ x = x + pos_embed
146
+ x = self.pos_drop(x)
147
+
148
+ for blk in self.blocks:
149
+ x = blk(x)
150
+
151
+ x = self.norm(x)
152
+
153
+ return x
154
+
155
+
156
+ activations = {}
157
+
158
+
159
+ def get_activation(name):
160
+ def hook(model, input, output):
161
+ activations[name] = output
162
+
163
+ return hook
164
+
165
+
166
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
167
+ if use_readout == "ignore":
168
+ readout_oper = [Slice(start_index)] * len(features)
169
+ elif use_readout == "add":
170
+ readout_oper = [AddReadout(start_index)] * len(features)
171
+ elif use_readout == "project":
172
+ readout_oper = [
173
+ ProjectReadout(vit_features, start_index) for out_feat in features
174
+ ]
175
+ else:
176
+ assert (
177
+ False
178
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
179
+
180
+ return readout_oper
181
+
182
+
183
+ def _make_vit_b16_backbone(
184
+ model,
185
+ features=[96, 192, 384, 768],
186
+ size=[384, 384],
187
+ hooks=[2, 5, 8, 11],
188
+ vit_features=768,
189
+ use_readout="ignore",
190
+ start_index=1,
191
+ ):
192
+ pretrained = nn.Module()
193
+
194
+ pretrained.model = model
195
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199
+
200
+ pretrained.activations = activations
201
+
202
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
203
+
204
+ # 32, 48, 136, 384
205
+ pretrained.act_postprocess1 = nn.Sequential(
206
+ readout_oper[0],
207
+ Transpose(1, 2),
208
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
209
+ nn.Conv2d(
210
+ in_channels=vit_features,
211
+ out_channels=features[0],
212
+ kernel_size=1,
213
+ stride=1,
214
+ padding=0,
215
+ ),
216
+ nn.ConvTranspose2d(
217
+ in_channels=features[0],
218
+ out_channels=features[0],
219
+ kernel_size=4,
220
+ stride=4,
221
+ padding=0,
222
+ bias=True,
223
+ dilation=1,
224
+ groups=1,
225
+ ),
226
+ )
227
+
228
+ pretrained.act_postprocess2 = nn.Sequential(
229
+ readout_oper[1],
230
+ Transpose(1, 2),
231
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
232
+ nn.Conv2d(
233
+ in_channels=vit_features,
234
+ out_channels=features[1],
235
+ kernel_size=1,
236
+ stride=1,
237
+ padding=0,
238
+ ),
239
+ nn.ConvTranspose2d(
240
+ in_channels=features[1],
241
+ out_channels=features[1],
242
+ kernel_size=2,
243
+ stride=2,
244
+ padding=0,
245
+ bias=True,
246
+ dilation=1,
247
+ groups=1,
248
+ ),
249
+ )
250
+
251
+ pretrained.act_postprocess3 = nn.Sequential(
252
+ readout_oper[2],
253
+ Transpose(1, 2),
254
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
255
+ nn.Conv2d(
256
+ in_channels=vit_features,
257
+ out_channels=features[2],
258
+ kernel_size=1,
259
+ stride=1,
260
+ padding=0,
261
+ ),
262
+ )
263
+
264
+ pretrained.act_postprocess4 = nn.Sequential(
265
+ readout_oper[3],
266
+ Transpose(1, 2),
267
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
268
+ nn.Conv2d(
269
+ in_channels=vit_features,
270
+ out_channels=features[3],
271
+ kernel_size=1,
272
+ stride=1,
273
+ padding=0,
274
+ ),
275
+ nn.Conv2d(
276
+ in_channels=features[3],
277
+ out_channels=features[3],
278
+ kernel_size=3,
279
+ stride=2,
280
+ padding=1,
281
+ ),
282
+ )
283
+
284
+ pretrained.model.start_index = start_index
285
+ pretrained.model.patch_size = [16, 16]
286
+
287
+ # We inject this function into the VisionTransformer instances so that
288
+ # we can use it with interpolated position embeddings without modifying the library source.
289
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
290
+ pretrained.model._resize_pos_embed = types.MethodType(
291
+ _resize_pos_embed, pretrained.model
292
+ )
293
+
294
+ return pretrained
295
+
296
+
297
+ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
298
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
299
+
300
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
301
+ return _make_vit_b16_backbone(
302
+ model,
303
+ features=[256, 512, 1024, 1024],
304
+ hooks=hooks,
305
+ vit_features=1024,
306
+ use_readout=use_readout,
307
+ )
308
+
309
+
310
+ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312
+
313
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
314
+ return _make_vit_b16_backbone(
315
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316
+ )
317
+
318
+
319
+ def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
321
+
322
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
323
+ return _make_vit_b16_backbone(
324
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325
+ )
326
+
327
+
328
+ def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329
+ model = timm.create_model(
330
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331
+ )
332
+
333
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
334
+ return _make_vit_b16_backbone(
335
+ model,
336
+ features=[96, 192, 384, 768],
337
+ hooks=hooks,
338
+ use_readout=use_readout,
339
+ start_index=2,
340
+ )
341
+
342
+
343
+ def _make_vit_b_rn50_backbone(
344
+ model,
345
+ features=[256, 512, 768, 768],
346
+ size=[384, 384],
347
+ hooks=[0, 1, 8, 11],
348
+ vit_features=768,
349
+ use_vit_only=False,
350
+ use_readout="ignore",
351
+ start_index=1,
352
+ ):
353
+ pretrained = nn.Module()
354
+
355
+     pretrained.model = model
+
+     if use_vit_only == True:
+         pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+         pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+     else:
+         pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+             get_activation("1")
+         )
+         pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+             get_activation("2")
+         )
+
+     pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+     pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+     pretrained.activations = activations
+
+     readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+     if use_vit_only == True:
+         pretrained.act_postprocess1 = nn.Sequential(
+             readout_oper[0],
+             Transpose(1, 2),
+             nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+             nn.Conv2d(
+                 in_channels=vit_features,
+                 out_channels=features[0],
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             ),
+             nn.ConvTranspose2d(
+                 in_channels=features[0],
+                 out_channels=features[0],
+                 kernel_size=4,
+                 stride=4,
+                 padding=0,
+                 bias=True,
+                 dilation=1,
+                 groups=1,
+             ),
+         )
+
+         pretrained.act_postprocess2 = nn.Sequential(
+             readout_oper[1],
+             Transpose(1, 2),
+             nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+             nn.Conv2d(
+                 in_channels=vit_features,
+                 out_channels=features[1],
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             ),
+             nn.ConvTranspose2d(
+                 in_channels=features[1],
+                 out_channels=features[1],
+                 kernel_size=2,
+                 stride=2,
+                 padding=0,
+                 bias=True,
+                 dilation=1,
+                 groups=1,
+             ),
+         )
+     else:
+         pretrained.act_postprocess1 = nn.Sequential(
+             nn.Identity(), nn.Identity(), nn.Identity()
+         )
+         pretrained.act_postprocess2 = nn.Sequential(
+             nn.Identity(), nn.Identity(), nn.Identity()
+         )
+
+     pretrained.act_postprocess3 = nn.Sequential(
+         readout_oper[2],
+         Transpose(1, 2),
+         nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+         nn.Conv2d(
+             in_channels=vit_features,
+             out_channels=features[2],
+             kernel_size=1,
+             stride=1,
+             padding=0,
+         ),
+     )
+
+     pretrained.act_postprocess4 = nn.Sequential(
+         readout_oper[3],
+         Transpose(1, 2),
+         nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+         nn.Conv2d(
+             in_channels=vit_features,
+             out_channels=features[3],
+             kernel_size=1,
+             stride=1,
+             padding=0,
+         ),
+         nn.Conv2d(
+             in_channels=features[3],
+             out_channels=features[3],
+             kernel_size=3,
+             stride=2,
+             padding=1,
+         ),
+     )
+
+     pretrained.model.start_index = start_index
+     pretrained.model.patch_size = [16, 16]
+
+     # We inject this function into the VisionTransformer instances so that
+     # we can use it with interpolated position embeddings without modifying the library source.
+     pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+     # We inject this function into the VisionTransformer instances so that
+     # we can use it with interpolated position embeddings without modifying the library source.
+     pretrained.model._resize_pos_embed = types.MethodType(
+         _resize_pos_embed, pretrained.model
+     )
+
+     return pretrained
+
+
+ def _make_pretrained_vitb_rn50_384(
+     pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+ ):
+     model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+     hooks = [0, 1, 8, 11] if hooks == None else hooks
+     return _make_vit_b_rn50_backbone(
+         model,
+         features=[256, 512, 768, 768],
+         size=[384, 384],
+         hooks=hooks,
+         use_vit_only=use_vit_only,
+         use_readout=use_readout,
+     )
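A minimal usage sketch of the backbone factory above (illustrative, not part of the commit): it assumes the surrounding vit.py helpers (forward_flex, get_activation, Transpose) are importable from external/midas/midas/vit.py in this repository layout, and that timm can fetch the vit_base_resnet50_384 weights. Running one forward pass lets the registered hooks fill pretrained.activations["1"]..["4"].

import torch
from external.midas.midas.vit import _make_pretrained_vitb_rn50_384  # assumed module path

# Build the hybrid ViT-B/ResNet-50 encoder used by DPT-Hybrid with the default hooks.
pretrained = _make_pretrained_vitb_rn50_384(True, hooks=[0, 1, 8, 11])
x = torch.randn(1, 3, 384, 384)
_ = pretrained.model.forward_flex(x)  # forward hooks populate pretrained.activations
print({name: act.shape for name, act in pretrained.activations.items()})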
external/midas/utils.py ADDED
@@ -0,0 +1,189 @@
+ """Utils for monoDepth."""
+ import sys
+ import re
+ import numpy as np
+ import cv2
+ import torch
+
+
+ def read_pfm(path):
+     """Read pfm file.
+
+     Args:
+         path (str): path to file
+
+     Returns:
+         tuple: (data, scale)
+     """
+     with open(path, "rb") as file:
+
+         color = None
+         width = None
+         height = None
+         scale = None
+         endian = None
+
+         header = file.readline().rstrip()
+         if header.decode("ascii") == "PF":
+             color = True
+         elif header.decode("ascii") == "Pf":
+             color = False
+         else:
+             raise Exception("Not a PFM file: " + path)
+
+         dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+         if dim_match:
+             width, height = list(map(int, dim_match.groups()))
+         else:
+             raise Exception("Malformed PFM header.")
+
+         scale = float(file.readline().decode("ascii").rstrip())
+         if scale < 0:
+             # little-endian
+             endian = "<"
+             scale = -scale
+         else:
+             # big-endian
+             endian = ">"
+
+         data = np.fromfile(file, endian + "f")
+         shape = (height, width, 3) if color else (height, width)
+
+         data = np.reshape(data, shape)
+         data = np.flipud(data)
+
+         return data, scale
+
+
+ def write_pfm(path, image, scale=1):
+     """Write pfm file.
+
+     Args:
+         path (str): path to file
+         image (array): data
+         scale (int, optional): Scale. Defaults to 1.
+     """
+
+     with open(path, "wb") as file:
+         color = None
+
+         if image.dtype.name != "float32":
+             raise Exception("Image dtype must be float32.")
+
+         image = np.flipud(image)
+
+         if len(image.shape) == 3 and image.shape[2] == 3:  # color image
+             color = True
+         elif (
+             len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+         ):  # greyscale
+             color = False
+         else:
+             raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+         file.write(("PF\n" if color else "Pf\n").encode())  # encode the header before writing to the binary stream
+         file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+         endian = image.dtype.byteorder
+
+         if endian == "<" or endian == "=" and sys.byteorder == "little":
+             scale = -scale
+
+         file.write("%f\n".encode() % scale)
+
+         image.tofile(file)
+
+
+ def read_image(path):
+     """Read image and output RGB image (0-1).
+
+     Args:
+         path (str): path to file
+
+     Returns:
+         array: RGB image (0-1)
+     """
+     img = cv2.imread(path)
+
+     if img.ndim == 2:
+         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+     return img
+
+
+ def resize_image(img):
+     """Resize image and make it fit for network.
+
+     Args:
+         img (array): image
+
+     Returns:
+         tensor: data ready for network
+     """
+     height_orig = img.shape[0]
+     width_orig = img.shape[1]
+
+     if width_orig > height_orig:
+         scale = width_orig / 384
+     else:
+         scale = height_orig / 384
+
+     height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+     width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+     img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+     img_resized = (
+         torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+     )
+     img_resized = img_resized.unsqueeze(0)
+
+     return img_resized
+
+
+ def resize_depth(depth, width, height):
+     """Resize depth map and bring to CPU (numpy).
+
+     Args:
+         depth (tensor): depth
+         width (int): image width
+         height (int): image height
+
+     Returns:
+         array: processed depth
+     """
+     depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+     depth_resized = cv2.resize(
+         depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+     )
+
+     return depth_resized
+
+ def write_depth(path, depth, bits=1):
+     """Write depth map to pfm and png file.
+
+     Args:
+         path (str): filepath without extension
+         depth (array): depth
+     """
+     write_pfm(path + ".pfm", depth.astype(np.float32))
+
+     depth_min = depth.min()
+     depth_max = depth.max()
+
+     max_val = (2 ** (8 * bits)) - 1
+
+     if depth_max - depth_min > np.finfo("float").eps:
+         out = max_val * (depth - depth_min) / (depth_max - depth_min)
+     else:
+         out = np.zeros(depth.shape, dtype=depth.dtype)
+
+     if bits == 1:
+         cv2.imwrite(path + ".png", out.astype("uint8"))
+     elif bits == 2:
+         cv2.imwrite(path + ".png", out.astype("uint16"))
+
+     return
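For context, a small round-trip sketch of the helpers above (illustrative only; the /tmp path and the random array are placeholders, and the import assumes the repository root is on sys.path):

import numpy as np
from external.midas.utils import read_pfm, write_depth

depth = np.random.rand(240, 320).astype(np.float32)  # stand-in for a MiDaS prediction
write_depth("/tmp/depth_example", depth, bits=2)      # writes /tmp/depth_example.pfm and .png
data, scale = read_pfm("/tmp/depth_example.pfm")
print(data.shape, scale)                              # (240, 320) and the stored scale factor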
inference.py CHANGED
@@ -14,6 +14,7 @@ from internals.pipelines.img_to_text import Image2Text
  from internals.pipelines.inpainter import InPainter
  from internals.pipelines.pose_detector import PoseDetector
  from internals.pipelines.prompt_modifier import PromptModifier
+ from internals.pipelines.replace_background import ReplaceBackground
  from internals.pipelines.safety_checker import SafetyChecker
  from internals.util.args import apply_style_args
  from internals.util.avatar import Avatar
@@ -41,6 +42,7 @@ inpainter = InPainter()
  high_res = HighRes()
  img2text = Image2Text()
  img_classifier = ImageClassifier()
+ replace_background = ReplaceBackground()
  controlnet = ControlNet()
  lora_style = LoraStyle()
  text2img_pipe = Text2Img()
@@ -84,7 +86,9 @@ def canny(task: Task):
      controlnet.load_canny()

      # pipe2 is used for canny and pose
-     lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
+     lora_patcher = lora_style.get_patcher(
+         [controlnet.pipe2, high_res.pipe], task.get_style()
+     )
      lora_patcher.patch()

      images, has_nsfw = controlnet.process_canny(
@@ -170,7 +174,9 @@ def scribble(task: Task):

      controlnet.load_scribble()

-     lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
+     lora_patcher = lora_style.get_patcher(
+         [controlnet.pipe2, high_res.pipe], task.get_style()
+     )
      lora_patcher.patch()

      images, has_nsfw = controlnet.process_scribble(
@@ -214,7 +220,9 @@ def linearart(task: Task):

      controlnet.load_linearart()

-     lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
+     lora_patcher = lora_style.get_patcher(
+         [controlnet.pipe2, high_res.pipe], task.get_style()
+     )
      lora_patcher.patch()

      images, has_nsfw = controlnet.process_linearart(
@@ -259,10 +267,13 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
      controlnet.load_pose()

      # pipe2 is used for canny and pose
-     lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
+     lora_patcher = lora_style.get_patcher(
+         [controlnet.pipe2, high_res.pipe], task.get_style()
+     )
      lora_patcher.patch()

      if not task.get_pose_estimation():
+         print("Not detecting pose")
          pose = download_image(task.get_imageUrl()).resize(
              (task.get_width(), task.get_height())
          )
@@ -278,9 +289,15 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
      else:
          poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences

+     src_image = download_image(task.get_auxilary_imageUrl()).resize(
+         (task.get_width(), task.get_height())
+     )
+     condition_image = ControlNet.linearart_condition_image(src_image)
+
      images, has_nsfw = controlnet.process_pose(
          prompt=prompt,
          image=poses,
+         condition_image=[condition_image] * num_return_sequences,
          seed=task.get_seed(),
          steps=task.get_steps(),
          negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
@@ -299,8 +316,8 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
          steps=task.get_steps(),
      )

-     pose_output_key = "crecoAI/{}_pose.png".format(task.get_taskId())
-     upload_image(poses[0], pose_output_key)
+     upload_image(poses[0], "crecoAI/{}_pose.png".format(task.get_taskId()))
+     upload_image(condition_image, "crecoAI/{}_condition.png".format(task.get_taskId()))

      generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())

@@ -322,7 +339,9 @@ def text2img(task: Task):

      width, height = get_intermediate_dimension(task)

-     lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
+     lora_patcher = lora_style.get_patcher(
+         [text2img_pipe.pipe, high_res.pipe], task.get_style()
+     )
      lora_patcher.patch()

      torch.manual_seed(task.get_seed())
@@ -366,7 +385,9 @@ def img2img(task: Task):

      width, height = get_intermediate_dimension(task)

-     lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
+     lora_patcher = lora_style.get_patcher(
+         [img2img_pipe.pipe, high_res.pipe], task.get_style()
+     )
      lora_patcher.patch()

      torch.manual_seed(task.get_seed())
@@ -427,6 +448,42 @@ def inpaint(task: Task):
      return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


+ @update_db
+ @slack.auto_send_alert
+ def replace_bg(task: Task):
+     prompt = task.get_prompt()
+     if task.is_prompt_engineering():
+         prompt = prompt_modifier.modify(prompt)
+     else:
+         prompt = [prompt] * num_return_sequences
+
+     lora_patcher = lora_style.get_patcher(replace_background.pipe, task.get_style())
+     lora_patcher.patch()
+
+     images, has_nsfw = replace_background.replace(
+         image=task.get_imageUrl(),
+         prompt=prompt,
+         negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+         seed=task.get_seed(),
+         width=task.get_width(),
+         height=task.get_height(),
+         steps=task.get_steps(),
+         resize_dimension=task.get_resize_dimension(),
+         product_scale_width=task.get_image_scale(),
+         conditioning_scale=task.rbg_controlnet_conditioning_scale(),
+     )
+
+     generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
+
+     lora_patcher.cleanup()
+
+     return {
+         "modified_prompts": prompt,
+         "generated_image_urls": generated_image_urls,
+         "has_nsfw": has_nsfw,
+     }
+
+
  def load_model_by_task(task: Task):
      if (
          task.get_type()
@@ -444,6 +501,8 @@ def load_model_by_task(task: Task):

          safety_checker.apply(text2img_pipe)
          safety_checker.apply(img2img_pipe)
+     elif task.get_type() == TaskType.REPLACE_BG:
+         replace_background.load(controlnet=controlnet)
      else:
          if task.get_type() == TaskType.TILE_UPSCALE:
              controlnet.load_tile_upscaler()
@@ -522,6 +581,8 @@ def predict_fn(data, pipe):
          return scribble(task)
      elif task_type == TaskType.LINEARART:
          return linearart(task)
+     elif task_type == TaskType.REPLACE_BG:
+         return replace_bg(task)
      elif task_type == TaskType.SYSTEM_CMD:
          os.system(task.get_prompt())
      else:
internals/data/dataAccessor.py CHANGED
@@ -70,8 +70,8 @@ def getStyles() -> Optional[Dict]:
      except requests.exceptions.Timeout:
          print("Request timed out while fetching styles")
      except requests.exceptions.RequestException as e:
-         raise e
          print(f"Error while fetching styles: {e}")
+         raise e
      return None


internals/data/task.py CHANGED
@@ -47,6 +47,9 @@ class Task:
      def get_imageUrl(self) -> str:
          return self.__data.get("imageUrl", None)

+     def get_auxilary_imageUrl(self) -> str:
+         return self.__data.get("aux_imageUrl", None)
+
      def get_prompt(self) -> str:
          return self.__data.get("prompt", "")

internals/pipelines/controlnets.py CHANGED
@@ -4,17 +4,27 @@ import cv2
  import numpy as np
  import torch
  from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
- from diffusers import (ControlNetModel, DiffusionPipeline,
-                        StableDiffusionControlNetPipeline,
-                        UniPCMultistepScheduler)
+ from diffusers import (
+     ControlNetModel,
+     DiffusionPipeline,
+     StableDiffusionControlNetPipeline,
+     UniPCMultistepScheduler,
+ )
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
+     MultiControlNetModel,
+ )
  from PIL import Image
  from torch.nn import Linear
  from tqdm import gui
+ from transformers import pipeline

+ import internals.util.image as ImageUtil
+ from external.midas import apply_midas
  from internals.data.result import Result
  from internals.pipelines.commons import AbstractPipeline
- from internals.pipelines.tileUpscalePipeline import \
-     StableDiffusionControlNetImg2ImgPipeline
+ from internals.pipelines.tileUpscalePipeline import (
+     StableDiffusionControlNetImg2ImgPipeline,
+ )
  from internals.util.cache import clear_cuda_and_gc
  from internals.util.commons import download_image
  from internals.util.config import get_hf_cache_dir, get_hf_token, get_model_dir
@@ -25,10 +35,11 @@ class ControlNet(AbstractPipeline):
      __loaded = False

      def load(self):
+         "Should not be called externally"
          if self.__loaded:
              return

-         if not self.controlnet:
+         if not hasattr(self, "controlnet"):
              self.load_pose()

          # controlnet pipeline for tile upscaler
@@ -79,15 +90,20 @@ class ControlNet(AbstractPipeline):
              torch_dtype=torch.float16,
              cache_dir=get_hf_cache_dir(),
          ).to("cuda")
+         # lineart = ControlNetModel.from_pretrained(
+         #     "ControlNet-1-1-preview/control_v11p_sd15_lineart",
+         #     torch_dtype=torch.float16,
+         #     cache_dir=get_hf_cache_dir(),
+         # ).to("cuda")
          self.__current_task_name = "pose"
-         self.controlnet = pose
+         self.controlnet = MultiControlNetModel([pose]).to("cuda")

          self.load()

          if hasattr(self, "pipe"):
-             self.pipe.controlnet = pose
+             self.pipe.controlnet = self.controlnet
          if hasattr(self, "pipe2"):
-             self.pipe2.controlnet = pose
+             self.pipe2.controlnet = self.controlnet
          clear_cuda_and_gc()

      def load_tile_upscaler(self):
@@ -195,6 +211,7 @@ class ControlNet(AbstractPipeline):
          self,
          prompt: List[str],
          image: List[Image.Image],
+         condition_image: List[Image.Image],
          seed: int,
          steps: int,
          guidance_scale: float,
@@ -208,14 +225,15 @@ class ControlNet(AbstractPipeline):
          torch.manual_seed(seed)

          result = self.pipe2.__call__(
-             prompt=prompt,
-             image=image,
-             num_images_per_prompt=1,
+             prompt=prompt[0],
+             image=[image[0]],
+             num_images_per_prompt=4,
              num_inference_steps=steps,
-             negative_prompt=negative_prompt,
+             negative_prompt=negative_prompt[0],
              guidance_scale=guidance_scale,
              height=height,
              width=width,
+             controlnet_conditioning_scale=[1.0],
          )
          return Result.from_result(result)

@@ -337,6 +355,15 @@ class ControlNet(AbstractPipeline):
          image = processor.__call__(input_image=image)
          return image

+     @staticmethod
+     def depth_image(image: Image.Image) -> Image.Image:
+         depth = np.array(image)
+         depth = ImageUtil.HWC3(depth)
+         depth, _ = apply_midas(depth)
+         depth = ImageUtil.HWC3(depth)
+         depth = Image.fromarray(depth)
+         return depth
+
      def __canny_detect_edge(self, image: Image.Image) -> Image.Image:
          image_array = np.array(image)

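A short usage sketch for the new depth helper above (illustrative only; the file name is a placeholder and the MiDaS weights fetched by external.midas are assumed to be available):

from PIL import Image
from internals.pipelines.controlnets import ControlNet

src = Image.open("product.png").convert("RGB")  # placeholder input image
depth = ControlNet.depth_image(src)             # PIL image holding the MiDaS depth map
depth.save("product_depth.png")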
internals/pipelines/remove_background.py CHANGED
@@ -7,6 +7,7 @@ import torch.nn.functional as F
  from PIL import Image
  from rembg import remove

+ import internals.util.image as ImageUtil
  from carvekit.api.high import HiInterface
  from internals.util.commons import download_image, read_url

@@ -40,6 +41,10 @@ class RemoveBackgroundV2:
          if type(image) is str:
              image = download_image(image)

+         w, h = image.size
+         if max(w, h) > 1536:
+             image = ImageUtil.resize_image(image, dimension=1024)
+
          image.save(img_path)
          images_without_background = self.interface([img_path])
          out = images_without_background[0]
internals/pipelines/replace_background.py CHANGED
@@ -1,5 +1,5 @@
  from io import BytesIO
- from typing import List, Union
+ from typing import List, Optional, Union

  import torch
  from diffusers import (
@@ -17,31 +17,55 @@ from internals.pipelines.controlnets import ControlNet
  from internals.pipelines.remove_background import RemoveBackgroundV2
  from internals.pipelines.upscaler import Upscaler
  from internals.util.commons import download_image
- from internals.util.config import get_hf_cache_dir
+ from internals.util.config import get_hf_cache_dir, get_model_dir


  class ReplaceBackground(AbstractPipeline):
-     def load(self, upscaler: Upscaler, remove_background: RemoveBackgroundV2):
-         controlnet = ControlNetModel.from_pretrained(
+     __loaded = False
+
+     def load(
+         self,
+         upscaler: Optional[Upscaler] = None,
+         remove_background: Optional[RemoveBackgroundV2] = None,
+         controlnet: Optional[ControlNet] = None,
+     ):
+         if self.__loaded:
+             return
+         controlnet_model = ControlNetModel.from_pretrained(
              "lllyasviel/control_v11p_sd15_lineart",
              torch_dtype=torch.float16,
              cache_dir=get_hf_cache_dir(),
          ).to("cuda")
-         pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
-             "runwayml/stable-diffusion-inpainting",
-             controlnet=controlnet,
-             torch_dtype=torch.float16,
-             cache_dir=get_hf_cache_dir(),
-         )
+         if controlnet:
+             controlnet.load_linearart()
+             pipe = StableDiffusionControlNetInpaintPipeline(
+                 **controlnet.pipe.components
+             )
+             pipe.controlnet = controlnet_model
+         else:
+             pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+                 get_model_dir(),
+                 controlnet=controlnet_model,
+                 torch_dtype=torch.float16,
+                 cache_dir=get_hf_cache_dir(),
+             )
          pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
          pipe.to("cuda")

-         upscaler.load()
-
          self.pipe = pipe
+         if not upscaler:
+             upscaler = Upscaler()
+
+         upscaler.load()
          self.upscaler = upscaler
+
+         if not remove_background:
+             remove_background = RemoveBackgroundV2()
          self.remove_background = remove_background

+         self.__loaded = True
+
+     @torch.inference_mode()
      def replace(
          self,
          image: Union[str, Image.Image],
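For orientation, a standalone sketch of the reworked loader (illustrative only; the URL, prompt, and numeric values are invented, the batch size of 4 mirrors num_return_sequences in inference.py, and the model downloads are assumed to be cached):

from internals.pipelines.replace_background import ReplaceBackground

replace_background = ReplaceBackground()
replace_background.load()  # builds its own Upscaler / RemoveBackgroundV2 when none are injected
images, has_nsfw = replace_background.replace(
    image="https://example.com/product.png",
    prompt=["product photo on a marble table, studio lighting"] * 4,
    negative_prompt=[""] * 4,
    seed=1337,
    width=512,
    height=512,
    steps=30,
    resize_dimension=1024,
    product_scale_width=0.6,
    conditioning_scale=0.9,
)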
internals/pipelines/upscaler.py CHANGED
@@ -125,9 +125,10 @@ class Upscaler:
      ) -> bytes:
          if type(image) is str:
              image = download_image(image)
-         w, h = image.size
-         if max(w, h) > 1536:
-             image = ImageUtil.resize_image(image, dimension=1536)
+
+         w, h = image.size
+         if max(w, h) > 1536:
+             image = ImageUtil.resize_image(image, dimension=1024)

          in_path = str(Path.home() / ".cache" / "input_upscale.png")
          image.save(in_path)
internals/util/image.py CHANGED
@@ -1,5 +1,6 @@
  import io

+ import numpy as np
  from PIL import Image


@@ -18,6 +19,26 @@ def resize_image(image: Image.Image, dimension: int = 512) -> Image.Image:
      return image


+ def HWC3(x):
+     "x: numpy array"
+     assert x.dtype == np.uint8
+     if x.ndim == 2:
+         x = x[:, :, None]
+     assert x.ndim == 3
+     H, W, C = x.shape
+     assert C == 1 or C == 3 or C == 4
+     if C == 3:
+         return x
+     if C == 1:
+         return np.concatenate([x, x, x], axis=2)
+     if C == 4:
+         color = x[:, :, 0:3].astype(np.float32)
+         alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+         y = color * alpha + 255.0 * (1.0 - alpha)
+         y = y.clip(0, 255).astype(np.uint8)
+         return y
+
+
  def from_bytes(data: bytes) -> Image.Image:
      return Image.open(io.BytesIO(data))

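A quick check of the HWC3 helper added above (illustrative only): a half-transparent RGBA array is composited over white and comes back as a plain 3-channel uint8 image.

import numpy as np
import internals.util.image as ImageUtil

rgba = np.zeros((64, 64, 4), dtype=np.uint8)  # black square
rgba[..., 3] = 128                            # roughly 50% alpha
rgb = ImageUtil.HWC3(rgba)
print(rgb.shape, rgb.dtype, rgb[0, 0])        # (64, 64, 3) uint8 [127 127 127]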
internals/util/lora_style.py CHANGED
@@ -21,18 +21,26 @@ class LoraStyle:

          @torch.inference_mode()
          def patch(self):
-             path = self.__style["path"]
-             if str(path).endswith((".pt", ".safetensors")):
-                 patch_pipe(self.pipe, self.__style["path"])
-                 tune_lora_scale(self.pipe.unet, self.__style["weight"])
-                 tune_lora_scale(self.pipe.text_encoder, self.__style["weight"])
+             def run(pipe):
+                 path = self.__style["path"]
+                 if str(path).endswith((".pt", ".safetensors")):
+                     patch_pipe(pipe, self.__style["path"])
+                     tune_lora_scale(pipe.unet, self.__style["weight"])
+                     tune_lora_scale(pipe.text_encoder, self.__style["weight"])
+
+             for p in self.pipe:
+                 run(p)

          def kwargs(self):
              return {}

          def cleanup(self):
-             tune_lora_scale(self.pipe.unet, 0.0)
-             tune_lora_scale(self.pipe.text_encoder, 0.0)
+             def run(pipe):
+                 tune_lora_scale(pipe.unet, 0.0)
+                 tune_lora_scale(pipe.text_encoder, 0.0)
+
+             for p in self.pipe:
+                 run(p)

      class LoraDiffuserPatcher:
          def __init__(self, pipe, style: Dict[str, Any]):
@@ -41,16 +49,24 @@ class LoraStyle:

          @torch.inference_mode()
          def patch(self):
-             path = self.__style["path"]
-             self.pipe.load_lora_weights(
-                 os.path.dirname(path), weight_name=os.path.basename(path)
-             )
+             def run(pipe):
+                 path = self.__style["path"]
+                 pipe.load_lora_weights(
+                     os.path.dirname(path), weight_name=os.path.basename(path)
+                 )
+
+             for p in self.pipe:
+                 run(p)

          def kwargs(self):
              return {}

          def cleanup(self):
-             LoraStyle.unload_lora_weights(self.pipe)
+             def run(pipe):
+                 LoraStyle.unload_lora_weights(pipe)
+
+             for p in self.pipe:
+                 run(p)

      class EmptyLoraPatcher:
          def __init__(self, pipe):
@@ -64,9 +80,13 @@ class LoraStyle:
              return {}

          def cleanup(self):
-             tune_lora_scale(self.pipe.unet, 0.0)
-             tune_lora_scale(self.pipe.text_encoder, 0.0)
-             LoraStyle.unload_lora_weights(self.pipe)
+             def run(pipe):
+                 tune_lora_scale(pipe.unet, 0.0)
+                 tune_lora_scale(pipe.text_encoder, 0.0)
+                 LoraStyle.unload_lora_weights(pipe)
+
+             for p in self.pipe:
+                 run(p)

      def load(self, model_dir: str):
          self.model = model_dir
@@ -77,8 +97,8 @@ class LoraStyle:
          result = getStyles()
          if result is not None:
              self.__styles = self.__parse_styles(model_dir, result["data"])
-         else:
-             self.__styles = self.__get_default_styles(model_dir)
+         if len(self.__styles) == 0:
+             print("Warning: No styles found for Model")
          self.__verify()

      def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
@@ -88,8 +108,10 @@ class LoraStyle:
          return prompt

      def get_patcher(
-         self, pipe, key: str
+         self, pipe: Union[Any, List], key: str
      ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
+         "Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
+         pipe = [pipe] if not isinstance(pipe, list) else pipe
          if key in self.__styles:
              style = self.__styles[key]
              if style["type"] == "diffuser":
@@ -119,49 +141,8 @@ class LoraStyle:
                      "text": attr["text"],
                      "negativePrompt": attr["negativePrompt"],
                  }
-         if len(styles) == 0:
-             return self.__get_default_styles(model_dir)
          return styles

-     def __get_default_styles(self, model_dir: str) -> Dict:
-         return {
-             "nq6akX1CIp": {
-                 "path": model_dir + "/laur_style/nq6akX1CIp/final_lora.safetensors",
-                 "text": ["nq6akX1CIp style"],
-                 "weight": 0.5,
-                 "negativePrompt": [""],
-                 "type": "custom",
-             },
-             "ghibli": {
-                 "path": model_dir + "/laur_style/nq6akX1CIp/ghibli.bin",
-                 "text": ["ghibli style"],
-                 "weight": 1,
-                 "negativePrompt": [""],
-                 "type": "custom",
-             },
-             "eQAmnK2kB2": {
-                 "path": model_dir + "/laur_style/eQAmnK2kB2/final_lora.safetensors",
-                 "text": ["eQAmnK2kB2 style"],
-                 "weight": 0.5,
-                 "negativePrompt": [""],
-                 "type": "custom",
-             },
-             "to8contrast": {
-                 "path": model_dir + "/laur_style/rpjgusOgqD/final_lora.bin",
-                 "text": ["to8contrast style"],
-                 "weight": 0.5,
-                 "negativePrompt": [""],
-                 "type": "custom",
-             },
-             "sfrrfz8vge": {
-                 "path": model_dir + "/laur_style/replicate/sfrrfz8vge.safetensors",
-                 "text": ["sfrrfz8vge style"],
-                 "weight": 1.2,
-                 "negativePrompt": [""],
-                 "type": "custom",
-             },
-         }
-
      def __verify(self):
          "A method to verify if lora exists within the required path otherwise throw error"

pyproject.toml CHANGED
@@ -1,4 +1,4 @@
  [tool.pyright]
- venvPath = "."
+ venvPath = "/Users/devel/Documents/WebProjects/creco-inference"
  venv = "env"
- exclude = "env"
+ exclude = ["env"]