Upload folder using huggingface_hub

Files changed:
- u2net/data_loader.py +85 -0
- u2net/download_data.sh +91 -0
- u2net/evaluate.py +41 -0
- u2net/inference.py +77 -0
- u2net/model.py +3 -2
- u2net/train.py +85 -1
u2net/data_loader.py
ADDED
@@ -0,0 +1,85 @@
```python
import os
import random
from PIL import Image

import torch
from torchvision import transforms
from sklearn.model_selection import train_test_split


class SaliencyDataset(torch.utils.data.Dataset):
    # Base class; subclasses provide set_directories() for their dataset layout.
    def __init__(self, split, img_size=512, val_split_ratio=0.05):
        self.img_size = img_size
        self.split = split
        self.image_dir, self.mask_dir = self.set_directories(split)

        all_images = [f for f in os.listdir(self.image_dir) if f.endswith('.jpg')]
        if split in ['train', 'valid']:
            train_imgs, val_imgs = train_test_split(all_images, test_size=val_split_ratio, random_state=42)
            self.images = train_imgs if split == 'train' else val_imgs
        else:
            self.images = all_images

        self.resize = transforms.Resize((img_size, img_size))
        self.normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_filename = self.images[idx]
        img_path = os.path.join(self.image_dir, img_filename)
        mask_filename = img_filename.replace('.jpg', '.png')
        mask_path = os.path.join(self.mask_dir, mask_filename)

        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        img, mask = self.resize(img), self.resize(mask)

        if self.split == 'train':
            img, mask = self.apply_augmentations(img, mask)

        img = transforms.ToTensor()(img)
        img = self.normalize(img)
        mask = transforms.ToTensor()(mask).squeeze(0)
        return img, mask

    def apply_augmentations(self, img, mask):
        if random.random() > 0.5:  # horizontal flip
            img = transforms.functional.hflip(img)
            mask = transforms.functional.hflip(mask)

        if random.random() > 0.5:  # random resized crop (same crop params for image and mask)
            i, j, h, w = transforms.RandomResizedCrop.get_params(img, scale=(0.8, 1.0), ratio=(3/4, 4/3))
            img = transforms.functional.resized_crop(img, i, j, h, w, (self.img_size, self.img_size))
            mask = transforms.functional.resized_crop(mask, i, j, h, w, (self.img_size, self.img_size))

        if random.random() > 0.5:  # color jitter (image only)
            color_jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05)
            img = color_jitter(img)

        return img, mask


class DUTSDataset(SaliencyDataset):
    def set_directories(self, split):
        train_or_test = 'train' if split in ['train', 'valid'] else 'test'
        image_dir = f'/data/duts_{train_or_test}_data/images'
        mask_dir = f'/data/duts_{train_or_test}_data/masks'
        return image_dir, mask_dir


class MSRADataset(SaliencyDataset):
    def set_directories(self, split):
        image_dir = '/data/msra_data/images'
        mask_dir = '/data/msra_data/masks'
        return image_dir, mask_dir


class PASCALSDataset(SaliencyDataset):
    def set_directories(self, split):
        image_dir = '/data/pascals_data/images'
        mask_dir = '/data/pascals_data/masks'
        return image_dir, mask_dir
```
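For reference, a minimal usage sketch (not part of the commit; it assumes the directories above exist under /data). The `split` argument selects a deterministic 95/5 train/valid partition of the same image list, thanks to the fixed `random_state=42`:

```python
# Minimal usage sketch for the dataset classes above.
from torch.utils.data import DataLoader

from data_loader import DUTSDataset

train_ds = DUTSDataset(split='train')   # augmented, 95% of DUTS-TR
valid_ds = DUTSDataset(split='valid')   # no augmentation, remaining 5%

img, mask = train_ds[0]                 # img: [3, 512, 512] float tensor, mask: [512, 512]
loader = DataLoader(train_ds, batch_size=8, shuffle=True)
```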
u2net/download_data.sh
ADDED
@@ -0,0 +1,91 @@
```bash
#!/bin/bash

DUTS_TRAIN_URL="http://saliencydetection.net/duts/download/DUTS-TR.zip"
DUTS_TEST_URL="http://saliencydetection.net/duts/download/DUTS-TE.zip"
MSRA10K_URL="http://mftp.mmcheng.net/Data/MSRA10K_Imgs_GT.zip"
PASCALS_URL="https://cbs.ic.gatech.edu/salobj/download/salObj.zip"

dataset_dir="data_temp"
mkdir -p "$dataset_dir"

download_and_extract() {
    url=$1
    filename=$(basename "$url")

    echo "Downloading $filename..."
    curl -L -o "$dataset_dir/$filename" "$url"

    echo "Extracting $filename..."
    unzip -q "$dataset_dir/$filename" -d "$dataset_dir"
    rm "$dataset_dir/$filename"
}

download_duts() {
    download_and_extract "$DUTS_TRAIN_URL"
    mv "$dataset_dir/DUTS-TR/DUTS-TR-Image" "$dataset_dir/DUTS-TR/images"
    mv "$dataset_dir/DUTS-TR/DUTS-TR-Mask" "$dataset_dir/DUTS-TR/masks"
    mv "$dataset_dir/DUTS-TR" "$dataset_dir/duts_train_data"

    download_and_extract "$DUTS_TEST_URL"
    mv "$dataset_dir/DUTS-TE/DUTS-TE-Image" "$dataset_dir/DUTS-TE/images"
    mv "$dataset_dir/DUTS-TE/DUTS-TE-Mask" "$dataset_dir/DUTS-TE/masks"
    mv "$dataset_dir/DUTS-TE" "$dataset_dir/duts_test_data"
}

download_msra() {
    download_and_extract "$MSRA10K_URL"
    rm -f "$dataset_dir/Readme.txt"
    # PNG masks and JPG images arrive mixed in Imgs/; separate them.
    mkdir -p "$dataset_dir/MSRA10K_Imgs_GT/masks"
    mv "$dataset_dir/MSRA10K_Imgs_GT/Imgs/"*.png "$dataset_dir/MSRA10K_Imgs_GT/masks"
    mv "$dataset_dir/MSRA10K_Imgs_GT/Imgs" "$dataset_dir/MSRA10K_Imgs_GT/images"
    mv "$dataset_dir/MSRA10K_Imgs_GT" "$dataset_dir/msra_data"
}

download_pascals() {
    download_and_extract "$PASCALS_URL"
    # Keep only the imgsal images/masks; drop everything else in the bundle.
    rm -rf "$dataset_dir/algmaps" "$dataset_dir/benchmark" "$dataset_dir/code" "$dataset_dir/results" \
        "$dataset_dir/readme.pdf" "$dataset_dir/tips_for_matlab.txt" "$dataset_dir/datasets/fixations" \
        "$dataset_dir/datasets/segments" "$dataset_dir/datasets/imgs/bruce" "$dataset_dir/datasets/imgs/cerf" \
        "$dataset_dir/datasets/imgs/ft" "$dataset_dir/datasets/imgs/judd" "$dataset_dir/datasets/imgs/pascal" \
        "$dataset_dir/datasets/masks/bruce" "$dataset_dir/datasets/masks/ft" "$dataset_dir/datasets/masks/pascal"
    mv "$dataset_dir/datasets/imgs/imgsal"/* "$dataset_dir/datasets/imgs"
    mv "$dataset_dir/datasets/masks/imgsal"/* "$dataset_dir/datasets/masks"
    rm -rf "$dataset_dir/datasets/imgs/imgsal" "$dataset_dir/datasets/imgs/Thumbs.db" "$dataset_dir/datasets/masks/imgsal"
    mv "$dataset_dir/datasets/imgs" "$dataset_dir/datasets/images"
    mv "$dataset_dir/datasets" "$dataset_dir/pascals_data"
}

usage() {
    echo "Usage: $0 [-d] [-m] [-p]"
    echo "  -d    Download DUTS dataset (train and test)"
    echo "  -m    Download MSRA10K dataset"
    echo "  -p    Download Pascal-S dataset"
    echo "If no options are provided, all datasets will be downloaded."
    exit 1
}

while getopts "dmp" opt; do
    case $opt in
        d) download_duts ;;
        m) download_msra ;;
        p) download_pascals ;;
        *) usage ;;
    esac
done

# Check if no options were provided
if [ $OPTIND -eq 1 ]; then
    echo "No options provided; downloading all datasets."
    download_duts
    download_msra
    download_pascals
fi
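Usage note (illustrative, matching the getopts block above): `bash download_data.sh -d` fetches only DUTS, `bash download_data.sh -m -p` fetches MSRA10K and Pascal-S, and running the script with no flags downloads all three into data_temp/. The dataset classes in data_loader.py expect the extracted folders under /data, so the contents of data_temp/ presumably get moved or mounted there afterwards.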
u2net/evaluate.py
ADDED
@@ -0,0 +1,41 @@
```python
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from data_loader import PASCALSDataset
from model import U2Net

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)


def load_model(model, model_path):
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()


def evaluate(model, loader, criterion):
    model.eval()
    running_loss = 0.
    with torch.no_grad():
        for images, masks in tqdm(loader, desc='Testing'):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = sum([criterion(output, masks) for output in outputs])
            running_loss += loss.item()
    return running_loss / len(loader)


if __name__ == '__main__':
    batch_size = 40

    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    model = U2Net().to(device)
    model = nn.DataParallel(model)
    load_model(model, 'results/inter-u2net-duts.pt')

    loader = DataLoader(PASCALSDataset(split='eval'), batch_size=batch_size, shuffle=False)

    loss = evaluate(model, loader, loss_fn)
    print(f'Loss: {loss:.4f}')
```
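Because train.py saves the state dict of a nn.DataParallel-wrapped model, the checkpoint keys carry a `module.` prefix; wrapping the freshly built U2Net in nn.DataParallel here (and in inference.py) is what makes load_state_dict line up. A hypothetical alternative sketch (not in the commit) that avoids the wrapper, e.g. for single-GPU or CPU evaluation:

```python
# Hypothetical alternative: strip DataParallel's 'module.' prefix
# from the checkpoint keys instead of re-wrapping the model.
import torch
from model import U2Net

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load('results/inter-u2net-duts.pt', map_location=device, weights_only=True)
state_dict = {k.removeprefix('module.'): v for k, v in state_dict.items()}

model = U2Net().to(device)
model.load_state_dict(state_dict)
model.eval()
```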
u2net/inference.py
ADDED
@@ -0,0 +1,77 @@
```python
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from model import U2Net

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def preprocess_image(image_path):
    img = Image.open(image_path).convert('RGB')
    preprocess = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    img = preprocess(img).unsqueeze(0).to(device)
    return img


def run_inference(model, image_path, threshold=0.5):
    input_img = preprocess_image(image_path)
    with torch.no_grad():
        d1, *_ = model(input_img)  # use the first output map only
    pred = torch.sigmoid(d1)
    pred = pred[0, :, :].cpu().numpy()

    pred = (pred - pred.min()) / (pred.max() - pred.min())  # min-max normalize to [0, 1]
    if threshold is not None:
        pred = (pred > threshold).astype(np.uint8) * 255  # binary mask
    else:
        pred = (pred * 255).astype(np.uint8)  # grayscale saliency map
    return pred


def overlay_segmentation(original_image, binary_mask, alpha=0.5):
    original_image = Image.open(original_image).convert('RGB').resize((512, 512), Image.BILINEAR)
    original_image_np = np.array(original_image)
    overlay = np.zeros_like(original_image_np)
    overlay[:, :, 0] = binary_mask  # paint the mask into the red channel
    overlay_image = (1 - alpha) * original_image_np + alpha * overlay
    overlay_image = overlay_image.astype(np.uint8)
    return overlay_image


if __name__ == '__main__':
    # ---
    model_path = 'results/inter-u2net-duts.pt'
    image_path = 'images/ladies.jpg'
    # ---
    model = U2Net().to(device)
    model = nn.DataParallel(model)
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    mask = run_inference(model, image_path, threshold=None)
    mask_with_threshold = run_inference(model, image_path, threshold=0.7)

    fig = plt.figure(figsize=(10, 10))
    gs = GridSpec(2, 2, figure=fig, wspace=0, hspace=0)

    images = [
        Image.open(image_path).resize((512, 512)),
        mask,
        overlay_segmentation(image_path, mask_with_threshold),
        mask_with_threshold
    ]

    for i, img in enumerate(images):
        ax = fig.add_subplot(gs[i // 2, i % 2])
        ax.imshow(img, cmap='gray' if i % 2 != 0 else None)
        ax.axis('off')

    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.savefig('inference-output.jpg', format='jpg', bbox_inches='tight', pad_inches=0)
```
u2net/model.py
CHANGED
```diff
@@ -29,7 +29,7 @@ class RSU(nn.Module):
         self.conv = ConvBlock(C_in, C_out)
 
         self.enc = nn.ModuleList([ConvBlock(C_out, M)])
-        for
+        for _ in range(L-2):
             self.enc.append(ConvBlock(M, M))
 
         self.mid = ConvBlock(M, M, dilation=2)
@@ -148,4 +148,5 @@ class U2Net(nn.Module):
 
         side_out.append(self.lastconv(torch.cat(side_out, dim=1)))
 
-
+        # logits
+        return [s.squeeze(1) for s in side_out]
```
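With this change, U2Net.forward returns a list of per-pixel logit maps (the side outputs plus the fused map produced by lastconv), each squeezed to shape [B, H, W]. train.py and evaluate.py rely on this by summing BCEWithLogitsLoss over every entry, the deep-supervision training scheme of U²-Net, while inference.py applies a sigmoid to the first entry only.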
u2net/train.py
CHANGED
@@ -1 +1,85 @@
```python
import pickle
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torch.amp import autocast, GradScaler

from data_loader import DUTSDataset, MSRADataset
from model import U2Net

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
scaler = GradScaler()


def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    running_loss = 0.
    for images, masks in tqdm(loader, desc='Training', leave=False):
        images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast(device_type='cuda'):
            outputs = model(images)
            loss = sum([criterion(output, masks) for output in outputs])
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
    return running_loss / len(loader)


def validate(model, loader, criterion):
    model.eval()
    running_loss = 0.
    with torch.no_grad():
        for images, masks in tqdm(loader, desc='Validating', leave=False):
            images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)
            outputs = model(images)
            loss = sum([criterion(output, masks) for output in outputs])
            running_loss += loss.item()
    avg_loss = running_loss / len(loader)
    return avg_loss


if __name__ == '__main__':
    batch_size = 40
    valid_batch_size = 80
    epochs = 100

    lr = 1e-4
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')

    model_name = 'u2net-duts'
    model = U2Net()
    model = torch.nn.DataParallel(model.to(device))
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    train_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
        batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=16, persistent_workers=True
    )
    valid_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='valid'), MSRADataset(split='valid')]),
        batch_size=valid_batch_size, shuffle=False, pin_memory=True,
        num_workers=16, persistent_workers=True
    )

    losses = {'train': [], 'val': []}
    for epoch in tqdm(range(epochs), desc='Epochs'):
        torch.cuda.empty_cache()
        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
        val_loss = validate(model, valid_loader, loss_fn)
        losses['train'].append(train_loss)
        losses['val'].append(val_loss)

        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), f'results/inter-{model_name}.pt')

        print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    torch.save(model.state_dict(), f'results/{model_name}.pt')
    with open('results/loss.txt', 'wb') as f:
        pickle.dump(losses, f)
```
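The pickled loss history can be inspected afterwards; a minimal sketch (hypothetical helper, not part of the commit, assuming matplotlib is available and the file names above):

```python
# Plot the loss curves pickled by train.py into results/loss.txt.
import pickle
import matplotlib.pyplot as plt

with open('results/loss.txt', 'rb') as f:
    losses = pickle.load(f)  # {'train': [...], 'val': [...]}

plt.plot(losses['train'], label='train')
plt.plot(losses['val'], label='val')
plt.xlabel('epoch')
plt.ylabel('summed BCE loss over side outputs')
plt.legend()
plt.savefig('results/loss-curves.png')
```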