Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- u2net/data_loader.py +26 -11
- u2net/evaluate.py +6 -3
- u2net/model.py +8 -5
- u2net/train.py +44 -23
u2net/data_loader.py
CHANGED
@@ -3,12 +3,13 @@ import random
|
|
3 |
from PIL import Image
|
4 |
|
5 |
import torch
|
|
|
6 |
from torchvision import transforms
|
7 |
from sklearn.model_selection import train_test_split
|
8 |
-
|
9 |
|
10 |
class SaliencyDataset(torch.utils.data.Dataset):
|
11 |
-
def __init__(self, split, img_size=512, val_split_ratio=0.05):
|
12 |
self.img_size = img_size
|
13 |
self.split = split
|
14 |
self.image_dir, self.mask_dir = self.set_directories(split)
|
@@ -19,8 +20,14 @@ class SaliencyDataset(torch.utils.data.Dataset):
|
|
19 |
self.images = train_imgs if split == 'train' else val_imgs
|
20 |
else:
|
21 |
self.images = all_images
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
-
self.
|
|
|
24 |
self.normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
25 |
|
26 |
def __len__(self):
|
@@ -33,30 +40,38 @@ class SaliencyDataset(torch.utils.data.Dataset):
|
|
33 |
mask_path = os.path.join(self.mask_dir, mask_filename)
|
34 |
|
35 |
img = Image.open(img_path).convert('RGB')
|
36 |
-
mask = Image.open(mask_path)
|
37 |
-
|
38 |
-
|
39 |
|
|
|
40 |
if self.split == 'train':
|
41 |
img, mask = self.apply_augmentations(img, mask)
|
42 |
|
43 |
img = transforms.ToTensor()(img)
|
44 |
img = self.normalize(img)
|
45 |
mask = transforms.ToTensor()(mask).squeeze(0)
|
|
|
46 |
return img, mask
|
47 |
|
48 |
def apply_augmentations(self, img, mask):
|
49 |
-
if random.random() > 0.5:
|
50 |
img = transforms.functional.hflip(img)
|
51 |
mask = transforms.functional.hflip(mask)
|
52 |
|
53 |
-
if random.random() > 0.5:
|
54 |
resized_crop = transforms.RandomResizedCrop(self.img_size, scale=(0.8, 1.0))
|
55 |
i, j, h, w = resized_crop.get_params(img, scale=(0.8, 1.0), ratio=(3/4, 4/3))
|
56 |
-
img = transforms.functional.resized_crop(
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
-
if random.random() > 0.5:
|
60 |
color_jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05)
|
61 |
img = color_jitter(img)
|
62 |
|
|
|
3 |
from PIL import Image
|
4 |
|
5 |
import torch
|
6 |
+
import numpy as np
|
7 |
from torchvision import transforms
|
8 |
from sklearn.model_selection import train_test_split
|
9 |
+
|
10 |
|
11 |
class SaliencyDataset(torch.utils.data.Dataset):
|
12 |
+
def __init__(self, split, img_size=512, val_split_ratio=0.05, subset_ratio=None):
|
13 |
self.img_size = img_size
|
14 |
self.split = split
|
15 |
self.image_dir, self.mask_dir = self.set_directories(split)
|
|
|
20 |
self.images = train_imgs if split == 'train' else val_imgs
|
21 |
else:
|
22 |
self.images = all_images
|
23 |
+
|
24 |
+
if subset_ratio: # subsampling
|
25 |
+
total_samples = len(self.images)
|
26 |
+
indices = np.random.choice(total_samples, int(total_samples * subset_ratio), replace=False)
|
27 |
+
self.images = [self.images[i] for i in indices]
|
28 |
|
29 |
+
self.img_resize = transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.BILINEAR)
|
30 |
+
self.mask_resize = transforms.Resize((img_size, img_size), interpolation=transforms.InterpolationMode.NEAREST)
|
31 |
self.normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
32 |
|
33 |
def __len__(self):
|
|
|
40 |
mask_path = os.path.join(self.mask_dir, mask_filename)
|
41 |
|
42 |
img = Image.open(img_path).convert('RGB')
|
43 |
+
mask = Image.open(mask_path)
|
44 |
+
if mask.mode != 'L': mask = mask.convert('L')
|
45 |
+
mask = mask.point(lambda p: 255 if p > 128 else 0)
|
46 |
|
47 |
+
img, mask = self.img_resize(img), self.mask_resize(mask)
|
48 |
if self.split == 'train':
|
49 |
img, mask = self.apply_augmentations(img, mask)
|
50 |
|
51 |
img = transforms.ToTensor()(img)
|
52 |
img = self.normalize(img)
|
53 |
mask = transforms.ToTensor()(mask).squeeze(0)
|
54 |
+
|
55 |
return img, mask
|
56 |
|
57 |
def apply_augmentations(self, img, mask):
|
58 |
+
if random.random() > 0.5: # horizontal flip
|
59 |
img = transforms.functional.hflip(img)
|
60 |
mask = transforms.functional.hflip(mask)
|
61 |
|
62 |
+
if random.random() > 0.5: # random resized crop
|
63 |
resized_crop = transforms.RandomResizedCrop(self.img_size, scale=(0.8, 1.0))
|
64 |
i, j, h, w = resized_crop.get_params(img, scale=(0.8, 1.0), ratio=(3/4, 4/3))
|
65 |
+
img = transforms.functional.resized_crop(
|
66 |
+
img, i, j, h, w, (self.img_size, self.img_size),
|
67 |
+
interpolation=transforms.InterpolationMode.BILINEAR
|
68 |
+
)
|
69 |
+
mask = transforms.functional.resized_crop(
|
70 |
+
mask, i, j, h, w, (self.img_size, self.img_size),
|
71 |
+
interpolation=transforms.InterpolationMode.NEAREST
|
72 |
+
)
|
73 |
|
74 |
+
if random.random() > 0.5: # color jitter
|
75 |
color_jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05)
|
76 |
img = color_jitter(img)
|
77 |
|
u2net/evaluate.py
CHANGED
@@ -3,6 +3,7 @@ from tqdm import tqdm
|
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
from torch.utils.data import DataLoader
|
|
|
6 |
|
7 |
from data_loader import PASCALSDataset
|
8 |
from model import U2Net
|
@@ -11,7 +12,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
11 |
print('Device:', device)
|
12 |
|
13 |
def load_model(model, model_path):
|
14 |
-
state_dict =
|
15 |
model.load_state_dict(state_dict)
|
16 |
model.eval()
|
17 |
|
@@ -28,12 +29,14 @@ def eval(model, loader, criterion):
|
|
28 |
|
29 |
|
30 |
if __name__ == '__main__':
|
31 |
-
batch_size =
|
32 |
|
|
|
|
|
33 |
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
|
34 |
model = U2Net().to(device)
|
35 |
model = nn.DataParallel(model)
|
36 |
-
load_model(model, 'results/
|
37 |
|
38 |
loader = DataLoader(PASCALSDataset(split='eval'), batch_size=batch_size, shuffle=False)
|
39 |
|
|
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
from torch.utils.data import DataLoader
|
6 |
+
from safetensors.torch import load_file
|
7 |
|
8 |
from data_loader import PASCALSDataset
|
9 |
from model import U2Net
|
|
|
12 |
print('Device:', device)
|
13 |
|
14 |
def load_model(model, model_path):
    """Load safetensors weights from ``model_path`` into ``model`` and set eval mode.

    NOTE(review): ``load_file`` is the safetensors loader and ``device`` is the
    module-level torch device — both assumed imported/defined at top of file.
    The model is switched to inference mode in place; nothing is returned.
    """
    # Map tensors straight onto the active device ('cuda' or 'cpu').
    state_dict = load_file(model_path, device=device.type)
    model.load_state_dict(state_dict)
    model.eval()
|
18 |
|
|
|
29 |
|
30 |
|
31 |
if __name__ == '__main__':
|
32 |
+
batch_size = 1
|
33 |
|
34 |
+
model_type = input('Model type [b,f]: ')
|
35 |
+
model_name = 'best-u2net-duts-msra.safetensors' if model_type == 'b' else 'u2net-duts-msra.safetensors'
|
36 |
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
|
37 |
model = U2Net().to(device)
|
38 |
model = nn.DataParallel(model)
|
39 |
+
load_model(model, f'results/{model_name}')
|
40 |
|
41 |
loader = DataLoader(PASCALSDataset(split='eval'), batch_size=batch_size, shuffle=False)
|
42 |
|
u2net/model.py
CHANGED
@@ -9,17 +9,19 @@ def init_weight(layer):
|
|
9 |
|
10 |
|
11 |
class ConvBlock(nn.Module):
|
12 |
-
def __init__(self, in_channel, out_channel, dilation=1):
|
13 |
super(ConvBlock, self).__init__()
|
14 |
self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=dilation, dilation=dilation)
|
15 |
self.bn = nn.BatchNorm2d(out_channel)
|
16 |
self.relu = nn.ReLU(inplace=True)
|
|
|
17 |
init_weight(self.conv)
|
18 |
|
19 |
def forward(self, x):
|
20 |
x = self.conv(x)
|
21 |
x = self.bn(x)
|
22 |
x = self.relu(x)
|
|
|
23 |
return x
|
24 |
|
25 |
|
@@ -93,7 +95,7 @@ class RSU4F(nn.Module):
|
|
93 |
|
94 |
|
95 |
class U2Net(nn.Module):
|
96 |
-
def __init__(self):
|
97 |
super(U2Net, self).__init__()
|
98 |
self.enc = nn.ModuleList([
|
99 |
RSU(L=7, C_in=3, C_out=64, M=32),
|
@@ -123,6 +125,7 @@ class U2Net(nn.Module):
|
|
123 |
|
124 |
self.lastconv = nn.Conv2d(6, 1, 1)
|
125 |
self.downsample = nn.MaxPool2d(2, stride=2)
|
|
|
126 |
|
127 |
init_weight(self.lastconv)
|
128 |
for conv in self.convs:
|
@@ -143,10 +146,10 @@ class U2Net(nn.Module):
|
|
143 |
|
144 |
side_out = []
|
145 |
for i, conv in enumerate(self.convs):
|
146 |
-
if i == 0: side_out.append(conv(dec_out[5]))
|
147 |
-
else: side_out.append(self.upsample(conv(dec_out[5-i]), side_out[0]))
|
148 |
|
149 |
side_out.append(self.lastconv(torch.cat(side_out, dim=1)))
|
150 |
|
151 |
-
# logits
|
152 |
return [s.squeeze(1) for s in side_out]
|
|
|
9 |
|
10 |
|
11 |
class ConvBlock(nn.Module):
    """Conv3x3 -> BatchNorm -> ReLU -> Dropout2d block.

    ``padding=dilation`` keeps the spatial size unchanged for any dilation
    factor with a 3x3 kernel.

    Args:
        in_channel: number of input feature channels.
        out_channel: number of output feature channels.
        dilation: dilation factor of the 3x3 convolution (default 1).
        dropout_rate: channel dropout probability (default 0.3).
    """
    def __init__(self, in_channel, out_channel, dilation=1, dropout_rate=0.3):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=dilation, dilation=dilation)
        self.bn = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(p=dropout_rate)  # custom - add dropout layer
        init_weight(self.conv)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        # BUG FIX: previously written as a bare `self.dropout(x)`, whose
        # result was discarded — nn.Dropout2d is NOT in-place, so dropout
        # was never actually applied. Assign the result back to x.
        x = self.dropout(x)
        return x
|
26 |
|
27 |
|
|
|
95 |
|
96 |
|
97 |
class U2Net(nn.Module):
|
98 |
+
def __init__(self, dropout_rate=0.3):
|
99 |
super(U2Net, self).__init__()
|
100 |
self.enc = nn.ModuleList([
|
101 |
RSU(L=7, C_in=3, C_out=64, M=32),
|
|
|
125 |
|
126 |
self.lastconv = nn.Conv2d(6, 1, 1)
|
127 |
self.downsample = nn.MaxPool2d(2, stride=2)
|
128 |
+
self.dropout = nn.Dropout(p=dropout_rate) # custom - add dropout layer
|
129 |
|
130 |
init_weight(self.lastconv)
|
131 |
for conv in self.convs:
|
|
|
146 |
|
147 |
side_out = []
|
148 |
for i, conv in enumerate(self.convs):
|
149 |
+
if i == 0: side_out.append(self.dropout(conv(dec_out[5])))
|
150 |
+
else: side_out.append(self.upsample(self.dropout(conv(dec_out[5-i])), side_out[0]))
|
151 |
|
152 |
side_out.append(self.lastconv(torch.cat(side_out, dim=1)))
|
153 |
|
154 |
+
# logits (no sigmoid)
|
155 |
return [s.squeeze(1) for s in side_out]
|
u2net/train.py
CHANGED
@@ -14,6 +14,20 @@ from model import U2Net
|
|
14 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
scaler = GradScaler()
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
def train_one_epoch(model, loader, criterion, optimizer):
|
18 |
model.train()
|
19 |
running_loss = 0.
|
@@ -43,47 +57,54 @@ def validate(model, loader, criterion):
|
|
43 |
avg_loss = running_loss / len(loader)
|
44 |
return avg_loss
|
45 |
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
if __name__ == '__main__':
|
48 |
batch_size = 40
|
49 |
valid_batch_size = 80
|
50 |
-
epochs =
|
51 |
|
52 |
lr = 1e-3
|
53 |
-
|
|
|
|
|
|
|
54 |
|
55 |
-
model_name = 'u2net-duts'
|
56 |
model = U2Net()
|
57 |
-
model = torch.nn.DataParallel(model.to(device))
|
58 |
-
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-
|
59 |
|
60 |
train_loader = DataLoader(
|
61 |
ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
|
62 |
batch_size=batch_size, shuffle=True, pin_memory=True,
|
63 |
-
num_workers=
|
64 |
)
|
65 |
valid_loader = DataLoader(
|
66 |
ConcatDataset([DUTSDataset(split='valid'), MSRADataset(split='valid')]),
|
67 |
batch_size=valid_batch_size, shuffle=False, pin_memory=True,
|
68 |
-
num_workers=
|
69 |
)
|
70 |
|
71 |
best_val_loss = float('inf')
|
72 |
losses = {'train': [], 'val': []}
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
best_val_loss = val_loss
|
82 |
-
save_file(model.state_dict(), f'results/best-{model_name}.safetensors')
|
83 |
-
print('Best model saved.')
|
84 |
-
|
85 |
-
print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best: {best_val_loss:.4f})')
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
|
14 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
scaler = GradScaler()
|
16 |
|
17 |
+
|
18 |
+
class DiceLoss(nn.Module):
    """Soft Dice loss on logits.

    Applies a sigmoid, flattens predictions and targets, and returns
    ``1 - Dice coefficient`` with additive smoothing to avoid division
    by zero on empty masks.
    """

    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # Logits -> probabilities, then flatten both tensors to 1-D.
        probs = torch.sigmoid(inputs).view(-1)
        flat_targets = targets.view(-1)

        # Soft Dice coefficient with smoothing term in numerator and denominator.
        overlap = torch.sum(probs * flat_targets)
        denom = torch.sum(probs) + torch.sum(flat_targets) + smooth
        dice_coeff = (2. * overlap + smooth) / denom
        return 1 - dice_coeff
|
29 |
+
|
30 |
+
|
31 |
def train_one_epoch(model, loader, criterion, optimizer):
|
32 |
model.train()
|
33 |
running_loss = 0.
|
|
|
57 |
avg_loss = running_loss / len(loader)
|
58 |
return avg_loss
|
59 |
|
60 |
+
def save(model, model_name, losses):
    """Persist model weights and the loss history to the results/ directory.

    NOTE(review): ``save_file`` is the safetensors writer and ``pickle`` is
    assumed imported at top of file — confirm against the file header.
    """
    # Weights go to results/<model_name>.safetensors.
    save_file(model.state_dict(), f'results/{model_name}.safetensors')
    # NOTE(review): despite the .txt suffix this is a binary pickle payload
    # ('wb' mode); readers must use pickle.load, not text parsing.
    with open('results/loss.txt', 'wb') as f:
        pickle.dump(losses, f)
|
64 |
+
|
65 |
|
66 |
if __name__ == '__main__':
    # --- hyperparameters ---
    batch_size = 40
    valid_batch_size = 80
    epochs = 200

    lr = 1e-3
    # Combined objective: weighted sum of BCE (on logits) and soft Dice loss.
    loss_fn_bce = nn.BCEWithLogitsLoss(reduction='mean')
    loss_fn_dice = DiceLoss()
    alpha = 0.6  # weight of the BCE term; (1 - alpha) goes to Dice
    loss_fn = lambda o, m: alpha * loss_fn_bce(o, m) + (1 - alpha) * loss_fn_dice(o, m)

    # --- model / optimizer ---
    model_name = 'u2net-duts-msra'
    model = U2Net()
    model = torch.nn.parallel.DataParallel(model.to(device))
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    # --- data: DUTS + MSRA concatenated for both train and validation ---
    train_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
        batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=8, persistent_workers=True
    )
    valid_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='valid'), MSRADataset(split='valid')]),
        batch_size=valid_batch_size, shuffle=False, pin_memory=True,
        num_workers=8, persistent_workers=True
    )

    best_val_loss = float('inf')
    losses = {'train': [], 'val': []}

    # training loop
    try:
        for epoch in range(epochs):
            train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
            val_loss = validate(model, valid_loader, loss_fn)
            losses['train'].append(train_loss)
            losses['val'].append(val_loss)

            # Checkpoint whenever validation improves.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_file(model.state_dict(), f'results/best-{model_name}.safetensors')

            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f} (Best: {best_val_loss:.4f})')
    finally:
        # Always persist the latest weights and loss history, even on
        # KeyboardInterrupt or a crash mid-training.
        save(model, model_name, losses)