ovshake committed
Commit 6724ca0
1 Parent(s): b17e19b

add app.py and related files

app.py ADDED
@@ -0,0 +1,137 @@
+ import streamlit as st
+ from diffusers import StableDiffusionInpaintPipeline
+ import os
+
+ from tqdm import tqdm
+ from PIL import Image
+ import numpy as np
+ import cv2
+ import warnings
+ from huggingface_hub import hf_hub_download
+
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+
+ from data.base_dataset import Normalize_image
+ from utils.saving_utils import load_checkpoint_mgpu
+ from networks import U2NET
+ import argparse
+ from enum import Enum
+ from rembg import remove
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class StableFashionCLIArgs:
+     # Annotated fields with defaults, so @dataclass registers them and the
+     # class can be instantiated with no arguments.
+     image: object = None
+     part: str = "Upper"
+     resolution: int = 512
+     prompt: str = ""
+     num_steps: int = 25
+     guidance_scale: float = 7.5
+     rembg: bool = False
+
+
+ class Parts:
+     UPPER = 1
+     LOWER = 2
+
+
+ def load_u2net():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     checkpoint_path = hf_hub_download(repo_id="maiti/cloth-segmentation", filename="cloth_segm_u2net_latest.pth")
+     net = U2NET(in_ch=3, out_ch=4)
+     net = load_checkpoint_mgpu(net, checkpoint_path)
+     net = net.to(device)
+     net = net.eval()
+     return net
+
+
+ def change_bg_color(rgba_image, color):
+     new_image = Image.new("RGBA", rgba_image.size, color)
+     new_image.paste(rgba_image, (0, 0), rgba_image)
+     return new_image.convert("RGB")
+
+
+ def load_inpainting_pipeline():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+         "runwayml/stable-diffusion-inpainting",
+         revision="fp16",
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     ).to(device)
+     return inpainting_pipeline
+
+
+ def process_image(args, inpainting_pipeline, net):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     transform_rgb = transforms.Compose([transforms.ToTensor(), Normalize_image(0.5, 0.5)])
+     # args.image is a PIL image when set from the Streamlit upload;
+     # Image.open() is only needed when a file path is passed instead.
+     img = args.image if isinstance(args.image, Image.Image) else Image.open(args.image)
+     img = img.convert("RGB")
+     img = img.resize((args.resolution, args.resolution))
+     if args.rembg:
+         img_with_green_bg = remove(img)
+         img_with_green_bg = change_bg_color(img_with_green_bg, color="GREEN")
+         img_with_green_bg = img_with_green_bg.convert("RGB")
+     else:
+         img_with_green_bg = img
+     image_tensor = transform_rgb(img_with_green_bg)
+     image_tensor = image_tensor.unsqueeze(0)
+     with torch.no_grad():
+         output_tensor = net(image_tensor.to(device))
+     output_tensor = F.log_softmax(output_tensor[0], dim=1)
+     output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
+     output_tensor = torch.squeeze(output_tensor, dim=0)
+     output_tensor = torch.squeeze(output_tensor, dim=0)
+     output_arr = output_tensor.cpu().numpy()
+     mask_code = getattr(Parts, args.part.upper())
+     mask = (output_arr == mask_code)
+     output_arr[mask] = 1
+     output_arr[~mask] = 0
+     output_arr *= 255
+     mask_PIL = Image.fromarray(output_arr.astype("uint8"), mode="L")
+     clothed_image_from_pipeline = inpainting_pipeline(prompt=args.prompt,
+                                                       image=img_with_green_bg,
+                                                       mask_image=mask_PIL,
+                                                       width=args.resolution,
+                                                       height=args.resolution,
+                                                       guidance_scale=args.guidance_scale,
+                                                       num_inference_steps=args.num_steps).images[0]
+     clothed_image_from_pipeline = remove(clothed_image_from_pipeline)
+     clothed_image_from_pipeline = change_bg_color(clothed_image_from_pipeline, "WHITE")
+     return clothed_image_from_pipeline.convert("RGB")
+
+
+ st.title("Stable Fashion Huggingface Spaces")
+ file_name = st.file_uploader("Upload a clear full-length picture of yourself, preferably against an uncluttered background")
+ net = load_u2net()
+ inpainting_pipeline = load_inpainting_pipeline()
+
+ if file_name is not None:
+     image = Image.open(file_name)
+     stable_fashion_args = StableFashionCLIArgs()
+     stable_fashion_args.image = image
+     body_part = st.radio("Would you like to try clothes on your upper body (shirts, kurtas etc.) or your lower body (jeans, pants etc.)?", ('Upper', 'Lower'))
+     stable_fashion_args.part = body_part
+     resolution = st.radio("Which resolution would you like for the resulting picture? (Keep in mind: the higher the resolution, the longer the queue times.)", (128, 256, 512))
+     stable_fashion_args.resolution = resolution
+     rembg_status = st.radio("Would you like to remove the background of your image before putting new clothes on you? (This sometimes gives better results.)", ("Yes", "No"))
+     stable_fashion_args.rembg = (rembg_status == "Yes")
+     guidance_scale = st.slider("Select a guidance scale. 7.5 gives the best results.", 1.0, 15.0, value=7.5)
+     stable_fashion_args.guidance_scale = guidance_scale
+     prompt = st.text_input('Describe the clothing you want to try on', 'a bright yellow t shirt')
+     stable_fashion_args.prompt = prompt
+     num_steps = st.slider("No. of inference steps for the diffusion process", 5, 50, value=25)
+     stable_fashion_args.num_steps = num_steps
+
+     result_image = process_image(stable_fashion_args, inpainting_pipeline, net)
+     st.image(result_image, caption='Your new outfit')
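Note: the app above is launched with "streamlit run app.py". For readability, the mask-extraction step inside process_image can be read as a standalone helper. This is a minimal sketch restating that logic, assuming the 4-class cloth-segmentation head loaded by load_u2net (label 1 = upper body, label 2 = lower body, per the Parts class); the helper name part_mask is hypothetical:

    import torch
    import torch.nn.functional as F
    from PIL import Image

    def part_mask(logits: torch.Tensor, part_label: int) -> Image.Image:
        # logits: the fused d0 output of U2NET, shape (1, 4, H, W)
        scores = F.log_softmax(logits, dim=1)      # per-pixel class scores
        labels = scores.argmax(dim=1).squeeze(0)   # (H, W) integer class map
        mask = (labels == part_label).to(torch.uint8) * 255
        return Image.fromarray(mask.cpu().numpy(), mode="L")

The resulting "L"-mode image marks the selected garment region in white, which is exactly what the inpainting pipeline expects as mask_image.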
data/__pycache__/base_dataset.cpython-39.pyc ADDED
Binary file (5.75 kB).
data/base_dataset.py ADDED
@@ -0,0 +1,189 @@
+ import os
+ from PIL import Image
+ import cv2
+ import numpy as np
+ import random
+
+ import torch
+ import torch.utils.data as data
+ import torchvision.transforms as transforms
+
+
+ class BaseDataset(data.Dataset):
+     def __init__(self):
+         super(BaseDataset, self).__init__()
+
+     def name(self):
+         return "BaseDataset"
+
+     def initialize(self, opt):
+         pass
+
+
+ class Rescale_fixed(object):
+     """Rescale the input image to a given fixed size.
+
+     Args:
+         output_size (tuple): desired output size (w, h).
+     """
+
+     def __init__(self, output_size):
+         self.output_size = output_size
+
+     def __call__(self, image):
+         return image.resize(self.output_size, Image.BICUBIC)
+
+
+ class Rescale_custom(object):
+     """Rescale the input and target images to a randomly selected size,
+     bounded below by min_size and above by max_size.
+
+     Args:
+         min_size (int): minimum desired output size.
+         max_size (int): maximum desired output size.
+     """
+
+     def __init__(self, min_size, max_size):
+         assert isinstance(min_size, (int, float))
+         self.min_size = min_size
+         self.max_size = max_size
+
+     def __call__(self, sample):
+
+         input_image, target_image = sample["input_image"], sample["target_image"]
+
+         assert input_image.size == target_image.size
+         w, h = input_image.size
+
+         # Randomly select a size to resize to
+         if min(self.max_size, h, w) > self.min_size:
+             self.output_size = np.random.randint(
+                 self.min_size, min(self.max_size, h, w)
+             )
+         else:
+             self.output_size = self.min_size
+
+         # calculate the new size, keeping the aspect ratio the same
+         if h > w:
+             new_h, new_w = self.output_size * h / w, self.output_size
+         else:
+             new_h, new_w = self.output_size, self.output_size * w / h
+
+         new_w, new_h = int(new_w), int(new_h)
+         input_image = input_image.resize((new_w, new_h), Image.BICUBIC)
+         target_image = target_image.resize((new_w, new_h), Image.BICUBIC)
+         return {"input_image": input_image, "target_image": target_image}
+
+
+ class ToTensor(object):
+     """Convert ndarrays in sample to Tensors."""
+
+     def __init__(self):
+         self.totensor = transforms.ToTensor()
+
+     def __call__(self, sample):
+         input_image, target_image = sample["input_image"], sample["target_image"]
+
+         return {
+             "input_image": self.totensor(input_image),
+             "target_image": self.totensor(target_image),
+         }
+
+
+ class RandomCrop_custom(object):
+     """Crop the image in a sample at a random location.
+
+     Args:
+         output_size (tuple or int): desired output size. If int, a square
+             crop is made.
+     """
+
+     def __init__(self, output_size):
+         assert isinstance(output_size, (int, tuple))
+         if isinstance(output_size, int):
+             self.output_size = (output_size, output_size)
+         else:
+             assert len(output_size) == 2
+             self.output_size = output_size
+
+         self.randomcrop = transforms.RandomCrop(self.output_size)
+
+     def __call__(self, sample):
+         input_image, target_image = sample["input_image"], sample["target_image"]
+         # Concatenate along the channel dim so both images receive the same crop.
+         cropped_imgs = self.randomcrop(torch.cat((input_image, target_image)))
+
+         return {
+             "input_image": cropped_imgs[:3, :, :],
+             "target_image": cropped_imgs[3:, :, :],
+         }
+
+
+ class Normalize_custom(object):
+     """Normalize a sample dict with the given mean and standard deviation.
+
+     Args:
+         mean (tuple or float): desired mean to subtract from the dict's tensors.
+         std (tuple or float): desired std to divide the dict's tensors by.
+     """
+
+     def __init__(self, mean, std):
+         assert isinstance(mean, (float, tuple))
+         if isinstance(mean, float):
+             self.mean = (mean, mean, mean)
+         else:
+             assert len(mean) == 3
+             self.mean = mean
+
+         if isinstance(std, float):
+             self.std = (std, std, std)
+         else:
+             assert len(std) == 3
+             self.std = std
+
+         self.normalize = transforms.Normalize(self.mean, self.std)
+
+     def __call__(self, sample):
+         input_image, target_image = sample["input_image"], sample["target_image"]
+
+         return {
+             "input_image": self.normalize(input_image),
+             "target_image": self.normalize(target_image),
+         }
+
+
+ class Normalize_image(object):
+     """Normalize a tensor with the given mean and standard deviation,
+     for 1-, 3-, or 18-channel inputs.
+
+     Args:
+         mean (float): desired mean to subtract from the tensor.
+         std (float): desired std to divide the tensor by.
+     """
+
+     def __init__(self, mean, std):
+         assert isinstance(mean, float)
+         assert isinstance(std, float)
+         self.mean = mean
+         self.std = std
+
+         self.normalize_1 = transforms.Normalize(self.mean, self.std)
+         self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3)
+         self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18)
+
+     def __call__(self, image_tensor):
+         if image_tensor.shape[0] == 1:
+             return self.normalize_1(image_tensor)
+         elif image_tensor.shape[0] == 3:
+             return self.normalize_3(image_tensor)
+         elif image_tensor.shape[0] == 18:
+             return self.normalize_18(image_tensor)
+         else:
+             # The original `assert "<message>"` was always true; raise instead.
+             raise ValueError(
+                 "Please set proper channels! Normalization implemented only for 1, 3 and 18"
+             )
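Note: a minimal sketch of how the rest of this commit uses Normalize_image; the zero tensor below is illustrative:

    import torch
    import torchvision.transforms as transforms
    from data.base_dataset import Normalize_image

    # ToTensor scales pixels to [0, 1]; Normalize_image(0.5, 0.5) then maps a
    # 1-, 3-, or 18-channel tensor to roughly [-1, 1].
    transform_rgb = transforms.Compose([transforms.ToTensor(), Normalize_image(0.5, 0.5)])

    x = torch.zeros(3, 64, 64)                 # stand-in for an already-converted image
    print(Normalize_image(0.5, 0.5)(x).min())  # tensor(-1.)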
main.py ADDED
@@ -0,0 +1,112 @@
+ from diffusers import StableDiffusionInpaintPipeline
+ import os
+
+ from tqdm import tqdm
+ from PIL import Image
+ import numpy as np
+ import cv2
+ import warnings
+
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+
+ from data.base_dataset import Normalize_image
+ from utils.saving_utils import load_checkpoint_mgpu
+ from networks import U2NET
+ import argparse
+ from enum import Enum
+ from rembg import remove
+
+
+ class Parts:
+     UPPER = 1
+     LOWER = 2
+
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser(
+         description="Stable Fashion API: picture yourself in any clothing your imagination can come up with!"
+     )
+     parser.add_argument('--image', type=str, required=True, help='path to image')
+     parser.add_argument('--part', choices=['upper', 'lower'], default='upper', type=str)
+     parser.add_argument('--resolution', choices=[256, 512, 1024, 2048], default=256, type=int)
+     parser.add_argument('--prompt', type=str, default="A pink cloth")
+     parser.add_argument('--num_steps', type=int, default=5)
+     parser.add_argument('--guidance_scale', type=float, default=7.5)
+     parser.add_argument('--rembg', action='store_true')
+     parser.add_argument('--output', default='output.jpg', type=str)
+     args, _ = parser.parse_known_args()
+     return args
+
+
+ def load_u2net():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     checkpoint_path = os.path.join("trained_checkpoint", "cloth_segm_u2net_latest.pth")
+     net = U2NET(in_ch=3, out_ch=4)
+     net = load_checkpoint_mgpu(net, checkpoint_path)
+     net = net.to(device)
+     net = net.eval()
+     return net
+
+
+ def change_bg_color(rgba_image, color):
+     new_image = Image.new("RGBA", rgba_image.size, color)
+     new_image.paste(rgba_image, (0, 0), rgba_image)
+     return new_image.convert("RGB")
+
+
+ def load_inpainting_pipeline():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+         "runwayml/stable-diffusion-inpainting",
+         revision="fp16",
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     ).to(device)
+     return inpainting_pipeline
+
+
+ def process_image(args, inpainting_pipeline, net):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     image_path = args.image
+     transform_rgb = transforms.Compose([transforms.ToTensor(), Normalize_image(0.5, 0.5)])
+     img = Image.open(image_path)
+     img = img.convert("RGB")
+     img = img.resize((args.resolution, args.resolution))
+     if args.rembg:
+         img_with_green_bg = remove(img)
+         img_with_green_bg = change_bg_color(img_with_green_bg, color="GREEN")
+         img_with_green_bg = img_with_green_bg.convert("RGB")
+     else:
+         img_with_green_bg = img
+     image_tensor = transform_rgb(img_with_green_bg)
+     image_tensor = image_tensor.unsqueeze(0)
+     with torch.no_grad():
+         output_tensor = net(image_tensor.to(device))
+     output_tensor = F.log_softmax(output_tensor[0], dim=1)
+     output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
+     output_tensor = torch.squeeze(output_tensor, dim=0)
+     output_tensor = torch.squeeze(output_tensor, dim=0)
+     output_arr = output_tensor.cpu().numpy()
+     mask_code = getattr(Parts, args.part.upper())
+     mask = (output_arr == mask_code)
+     output_arr[mask] = 1
+     output_arr[~mask] = 0
+     output_arr *= 255
+     mask_PIL = Image.fromarray(output_arr.astype("uint8"), mode="L")
+     clothed_image_from_pipeline = inpainting_pipeline(prompt=args.prompt,
+                                                       image=img_with_green_bg,
+                                                       mask_image=mask_PIL,
+                                                       width=args.resolution,
+                                                       height=args.resolution,
+                                                       guidance_scale=args.guidance_scale,
+                                                       num_inference_steps=args.num_steps).images[0]
+     clothed_image_from_pipeline = remove(clothed_image_from_pipeline)
+     clothed_image_from_pipeline = change_bg_color(clothed_image_from_pipeline, "WHITE")
+     return clothed_image_from_pipeline.convert("RGB")
+
+
+ if __name__ == '__main__':
+     args = parse_arguments()
+     net = load_u2net()
+     inpainting_pipeline = load_inpainting_pipeline()
+     result_image = process_image(args, inpainting_pipeline, net)
+     result_image.save(args.output)
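Note: main.py is normally driven from the shell, e.g. python main.py --image photo.jpg --part lower --prompt "blue denim jeans" --rembg. It can also be used programmatically; a minimal sketch under the assumption that the U2NET checkpoint sits in trained_checkpoint/ as load_u2net expects ("photo.jpg" is a placeholder path):

    from argparse import Namespace
    from main import load_u2net, load_inpainting_pipeline, process_image

    # Mirrors the defaults of parse_arguments(), with a higher step count.
    args = Namespace(image="photo.jpg", part="upper", resolution=256,
                     prompt="a bright yellow t shirt", num_steps=25,
                     guidance_scale=7.5, rembg=True, output="output.jpg")

    net = load_u2net()
    pipe = load_inpainting_pipeline()
    process_image(args, pipe, net).save(args.output)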
networks/__init__.py ADDED
@@ -0,0 +1 @@
+ from .u2net import U2NET
networks/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (204 Bytes).
networks/__pycache__/u2net.cpython-39.pyc ADDED
Binary file (10.5 kB).
networks/u2net.py ADDED
@@ -0,0 +1,565 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class REBNCONV(nn.Module):
+     def __init__(self, in_ch=3, out_ch=3, dirate=1):
+         super(REBNCONV, self).__init__()
+
+         self.conv_s1 = nn.Conv2d(
+             in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
+         )
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+         return xout
+
+
+ ## upsample tensor 'src' to have the same spatial size as tensor 'tar'
+ def _upsample_like(src, tar):
+
+     # F.interpolate is the non-deprecated equivalent of F.upsample
+     src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
+
+     return src
+
+
+ ### RSU-7 ###
+ class RSU7(nn.Module):  # UNet07DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU7, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+
+         hx6 = self.rebnconv6(hx)
+
+         hx7 = self.rebnconv7(hx6)
+
+         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+         hx6dup = _upsample_like(hx6d, hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5, hx6, hx7
+         del hx6d, hx5d, hx3d, hx2d
+         del hx2dup, hx3dup, hx4dup, hx5dup, hx6dup
+         """
+
+         return hx1d + hxin
+
+
+ ### RSU-6 ###
+ class RSU6(nn.Module):  # UNet06DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+
+         hx6 = self.rebnconv6(hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5, hx6
+         del hx5d, hx4d, hx3d, hx2d
+         del hx2dup, hx3dup, hx4dup, hx5dup
+         """
+
+         return hx1d + hxin
+
+
+ ### RSU-5 ###
+ class RSU5(nn.Module):  # UNet05DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+
+         hx5 = self.rebnconv5(hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5
+         del hx4d, hx3d, hx2d
+         del hx2dup, hx3dup, hx4dup
+         """
+
+         return hx1d + hxin
+
+
+ ### RSU-4 ###
+ class RSU4(nn.Module):  # UNet04DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4
+         del hx3d, hx2d
+         del hx2dup, hx3dup
+         """
+
+         return hx1d + hxin
+
+
+ ### RSU-4F ###
+ class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4
+         del hx3d, hx2d
+         """
+
+         return hx1d + hxin
+
+
+ ##### U^2-Net ####
+ class U2NET(nn.Module):
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NET, self).__init__()
+
+         self.stage1 = RSU7(in_ch, 32, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 32, 128)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(128, 64, 256)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(256, 128, 512)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(512, 256, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(512, 256, 512)
+
+         # decoder
+         self.stage5d = RSU4F(1024, 256, 512)
+         self.stage4d = RSU4(1024, 128, 256)
+         self.stage3d = RSU5(512, 64, 128)
+         self.stage2d = RSU6(256, 32, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+         self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
+
+     def forward(self, x):
+
+         hx = x
+
+         # stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # -------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side output
+         d1 = self.side1(hx1d)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5, hx6
+         del hx5d, hx4d, hx3d, hx2d, hx1d
+         del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
+         """
+
+         return d0, d1, d2, d3, d4, d5, d6
+
+
+ ### U^2-Net small ###
+ class U2NETP(nn.Module):
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NETP, self).__init__()
+
+         self.stage1 = RSU7(in_ch, 16, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 16, 64)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(64, 16, 64)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(64, 16, 64)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(64, 16, 64)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(64, 16, 64)
+
+         # decoder
+         self.stage5d = RSU4F(128, 16, 64)
+         self.stage4d = RSU4(128, 16, 64)
+         self.stage3d = RSU5(128, 16, 64)
+         self.stage2d = RSU6(128, 16, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
+
+         self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
+
+     def forward(self, x):
+
+         hx = x
+
+         # stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # decoder
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side output
+         d1 = self.side1(hx1d)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5, hx6
+         del hx5d, hx4d, hx3d, hx2d, hx1d
+         del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
+         """
+
+         return d0, d1, d2, d3, d4, d5, d6
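Note: the forward pass returns seven tensors, the fused prediction d0 plus six side outputs, each with out_ch channels at the input resolution. A minimal shape check, assuming the 4-channel configuration used elsewhere in this commit:

    import torch
    from networks.u2net import U2NET

    net = U2NET(in_ch=3, out_ch=4).eval()
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        d0, d1, d2, d3, d4, d5, d6 = net(x)
    print(d0.shape)  # torch.Size([1, 4, 256, 256]); d1..d6 are the side outputs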
requirements.txt ADDED
@@ -0,0 +1,91 @@
+ absl-py==1.2.0
+ accelerate==0.12.0
+ aiohttp==3.8.1
+ aiosignal==1.2.0
+ asttokens==2.0.8
+ async-timeout==4.0.2
+ attrs==22.1.0
+ backcall==0.2.0
+ beautifulsoup4==4.11.1
+ bitsandbytes==0.33.1
+ cachetools==5.2.0
+ charset-normalizer==2.1.1
+ clip @ git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
+ datasets==2.4.0
+ decorator==5.1.1
+ diffusers @ git+https://github.com/ovshake/diffusers@5bbecb751764248755943b57d900ae14a7f43a75
+ dill==0.3.5.1
+ executing==1.0.0
+ filelock==3.8.0
+ frozenlist==1.3.1
+ fsspec==2022.8.2
+ ftfy==6.1.1
+ gdown==4.5.1
+ google-auth==2.11.0
+ google-auth-oauthlib==0.4.6
+ grpcio==1.47.0
+ huggingface-hub==0.11.0
+ idna==3.3
+ importlib-metadata==4.12.0
+ ipdb==0.13.9
+ ipython==8.4.0
+ jedi==0.18.1
+ Jinja2==3.1.2
+ joblib==1.1.0
+ Markdown==3.4.1
+ MarkupSafe==2.1.1
+ matplotlib-inline==0.1.6
+ modelcards==0.1.6
+ multidict==6.0.2
+ multiprocess==0.70.13
+ numpy==1.23.2
+ oauthlib==3.2.0
+ opencv-python==4.6.0.66
+ packaging==21.3
+ pandas==1.4.4
+ parso==0.8.3
+ pexpect==4.8.0
+ pickleshare==0.7.5
+ Pillow==9.2.0
+ prompt-toolkit==3.0.30
+ protobuf==3.19.4
+ psutil==5.9.1
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ pyarrow==9.0.0
+ pyasn1==0.4.8
+ pyasn1-modules==0.2.8
+ Pygments==2.13.0
+ pyparsing==3.0.9
+ PySocks==1.7.1
+ python-dateutil==2.8.2
+ pytz==2022.2.1
+ PyYAML==6.0
+ regex==2022.8.17
+ requests==2.28.1
+ requests-oauthlib==1.3.1
+ responses==0.18.0
+ rsa==4.9
+ six==1.16.0
+ soupsieve==2.3.2.post1
+ stack-data==0.5.0
+ tensorboard==2.10.0
+ tensorboard-data-server==0.6.1
+ tensorboard-plugin-wit==1.8.1
+ tensorboardX==2.5.1
+ tokenizers==0.13.2
+ toml==0.10.2
+ torch==1.13.0+cu116
+ torchaudio==0.13.0+cu116
+ torchvision==0.14.0+cu116
+ tqdm==4.64.0
+ traitlets==5.3.0
+ transformers==4.24.0
+ typing_extensions==4.3.0
+ urllib3==1.26.12
+ wcwidth==0.2.5
+ Werkzeug==2.2.2
+ xxhash==3.0.0
+ yarl==1.8.1
+ zipp==3.8.1
+ rembg
utils/__pycache__/saving_utils.cpython-39.pyc ADDED
Binary file (1.53 kB).
utils/saving_utils.py ADDED
@@ -0,0 +1,45 @@
+ import os
+ import copy
+ import cv2
+ import numpy as np
+ from collections import OrderedDict
+
+ import torch
+
+
+ def load_checkpoint(model, checkpoint_path):
+     if not os.path.exists(checkpoint_path):
+         print("----No checkpoints at given path----")
+         return
+     model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device("cpu")))
+     print("----checkpoints loaded from path: {}----".format(checkpoint_path))
+     return model
+
+
+ def load_checkpoint_mgpu(model, checkpoint_path):
+     if not os.path.exists(checkpoint_path):
+         print("----No checkpoints at given path----")
+         return
+     model_state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+     new_state_dict = OrderedDict()
+     for k, v in model_state_dict.items():
+         name = k[7:]  # strip the "module." prefix added by nn.DataParallel
+         new_state_dict[name] = v
+
+     model.load_state_dict(new_state_dict)
+     print("----checkpoints loaded from path: {}----".format(checkpoint_path))
+     return model
+
+
+ def save_checkpoint(model, save_path):
+     print(save_path)
+     if not os.path.exists(os.path.dirname(save_path)):
+         os.makedirs(os.path.dirname(save_path))
+     torch.save(model.state_dict(), save_path)
+
+
+ def save_checkpoints(opt, itr, net):
+     save_checkpoint(
+         net,
+         os.path.join(opt.save_dir, "checkpoints", "itr_{:08d}_u2net.pth".format(itr)),
+     )
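Note: a minimal round-trip sketch of why load_checkpoint_mgpu strips the first seven characters of each key: wrapping a model in nn.DataParallel prefixes every state-dict key with "module.". The toy model and the path "checkpoints/demo.pth" are illustrative:

    import torch.nn as nn
    from utils.saving_utils import save_checkpoint, load_checkpoint_mgpu

    model = nn.DataParallel(nn.Linear(4, 2))
    save_checkpoint(model, "checkpoints/demo.pth")  # keys: module.weight, module.bias

    fresh = nn.Linear(4, 2)                         # plain module, unprefixed keys
    fresh = load_checkpoint_mgpu(fresh, "checkpoints/demo.pth")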