jamino30 committed on
Commit
5464cad
1 Parent(s): 8084836

Upload folder using huggingface_hub

u2net/data_loader.py ADDED
@@ -0,0 +1,85 @@
+import os
+import random
+from PIL import Image
+
+import torch
+from torchvision import transforms
+from sklearn.model_selection import train_test_split
+
+
+class SaliencyDataset(torch.utils.data.Dataset):
+    def __init__(self, split, img_size=512, val_split_ratio=0.05):
+        self.img_size = img_size
+        self.split = split
+        self.image_dir, self.mask_dir = self.set_directories(split)
+
+        all_images = [f for f in os.listdir(self.image_dir) if f.endswith('.jpg')]
+        if split in ['train', 'valid']:
+            train_imgs, val_imgs = train_test_split(all_images, test_size=val_split_ratio, random_state=42)
+            self.images = train_imgs if split == 'train' else val_imgs
+        else:
+            self.images = all_images
+
+        self.resize = transforms.Resize((img_size, img_size))
+        self.normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        img_filename = self.images[idx]
+        img_path = os.path.join(self.image_dir, img_filename)
+        mask_filename = img_filename.replace('.jpg', '.png')
+        mask_path = os.path.join(self.mask_dir, mask_filename)
+
+        img = Image.open(img_path).convert('RGB')
+        mask = Image.open(mask_path).convert('L')
+
+        img, mask = self.resize(img), self.resize(mask)
+
+        if self.split == 'train':
+            img, mask = self.apply_augmentations(img, mask)
+
+        img = transforms.ToTensor()(img)
+        img = self.normalize(img)
+        mask = transforms.ToTensor()(mask).squeeze(0)
+        return img, mask
+
+    def apply_augmentations(self, img, mask):
+        if random.random() > 0.5:  # horizontal flip, applied jointly so image and mask stay aligned
+            img = transforms.functional.hflip(img)
+            mask = transforms.functional.hflip(mask)
+
+        if random.random() > 0.5:  # random resized crop with shared parameters for image and mask
+            resized_crop = transforms.RandomResizedCrop(self.img_size, scale=(0.8, 1.0))
+            i, j, h, w = resized_crop.get_params(img, scale=(0.8, 1.0), ratio=(3/4, 4/3))
+            img = transforms.functional.resized_crop(img, i, j, h, w, (self.img_size, self.img_size))
+            mask = transforms.functional.resized_crop(mask, i, j, h, w, (self.img_size, self.img_size))
+
+        if random.random() > 0.5:  # color jitter, image only (photometric changes leave the mask unchanged)
+            color_jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05)
+            img = color_jitter(img)
+
+        return img, mask
+
+
+class DUTSDataset(SaliencyDataset):
+    def set_directories(self, split):
+        train_or_test = 'train' if split in ['train', 'valid'] else 'test'
+        image_dir = f'/data/duts_{train_or_test}_data/images'
+        mask_dir = f'/data/duts_{train_or_test}_data/masks'
+        return image_dir, mask_dir
+
+
+class MSRADataset(SaliencyDataset):
+    def set_directories(self, split):
+        image_dir = '/data/msra_data/images'
+        mask_dir = '/data/msra_data/masks'
+        return image_dir, mask_dir
+
+
+class PASCALSDataset(SaliencyDataset):
+    def set_directories(self, split):
+        image_dir = '/data/pascals_data/images'
+        mask_dir = '/data/pascals_data/masks'
+        return image_dir, mask_dir
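For context, a minimal sketch of how these dataset classes would be consumed (illustrative only: it assumes the /data/duts_train_data/{images,masks} folders are already populated, e.g. via download_data.sh below, and the batch size is arbitrary):

from torch.utils.data import DataLoader
from data_loader import DUTSDataset

train_ds = DUTSDataset(split='train')      # 95% of DUTS-TR (val_split_ratio=0.05, random_state=42)
valid_ds = DUTSDataset(split='valid')      # the held-out 5%, same deterministic split
loader = DataLoader(train_ds, batch_size=4, shuffle=True)

images, masks = next(iter(loader))         # images: (4, 3, 512, 512), normalized; masks: (4, 512, 512) in [0, 1]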
u2net/download_data.sh ADDED
@@ -0,0 +1,91 @@
+#!/bin/bash
+
+DUTS_TRAIN_URL="http://saliencydetection.net/duts/download/DUTS-TR.zip"
+DUTS_TEST_URL="http://saliencydetection.net/duts/download/DUTS-TE.zip"
+MSRA10K_URL="http://mftp.mmcheng.net/Data/MSRA10K_Imgs_GT.zip"
+PASCALS_URL="https://cbs.ic.gatech.edu/salobj/download/salObj.zip"
+
+dataset_dir="data_temp"
+mkdir -p "$dataset_dir"
+
+download_and_extract() {
+    url=$1
+    filename=$(basename "$url")
+
+    echo "Downloading $filename..."
+    curl -L -o "$dataset_dir/$filename" "$url"
+
+    echo "Extracting $filename..."
+    unzip -q "$dataset_dir/$filename" -d "$dataset_dir"
+    rm "$dataset_dir/$filename"
+}
+
+download_duts() {
+    download_and_extract "$DUTS_TRAIN_URL"
+    mv "$dataset_dir/DUTS-TR/DUTS-TR-Image" "$dataset_dir/DUTS-TR/images"
+    mv "$dataset_dir/DUTS-TR/DUTS-TR-Mask" "$dataset_dir/DUTS-TR/masks"
+    mv "$dataset_dir/DUTS-TR" "$dataset_dir/duts_train_data"
+
+    download_and_extract "$DUTS_TEST_URL"
+    mv "$dataset_dir/DUTS-TE/DUTS-TE-Image" "$dataset_dir/DUTS-TE/images"
+    mv "$dataset_dir/DUTS-TE/DUTS-TE-Mask" "$dataset_dir/DUTS-TE/masks"
+    mv "$dataset_dir/DUTS-TE" "$dataset_dir/duts_test_data"
+}
+
+download_msra() {
+    download_and_extract "$MSRA10K_URL"
+    rm -f "$dataset_dir/Readme.txt"
+    mkdir -p "$dataset_dir/MSRA10K_Imgs_GT/masks"
+    mv "$dataset_dir/MSRA10K_Imgs_GT/Imgs/"*.png "$dataset_dir/MSRA10K_Imgs_GT/masks"
+    mv "$dataset_dir/MSRA10K_Imgs_GT/Imgs" "$dataset_dir/MSRA10K_Imgs_GT/images"
+    mv "$dataset_dir/MSRA10K_Imgs_GT" "$dataset_dir/msra_data"
+}
+
+download_pascals() {
+    download_and_extract "$PASCALS_URL"
+    rm -rf "$dataset_dir/algmaps" "$dataset_dir/benchmark" "$dataset_dir/code" "$dataset_dir/results" \
+        "$dataset_dir/readme.pdf" "$dataset_dir/tips_for_matlab.txt" "$dataset_dir/datasets/fixations" \
+        "$dataset_dir/datasets/segments" "$dataset_dir/datasets/imgs/bruce" "$dataset_dir/datasets/imgs/cerf" \
+        "$dataset_dir/datasets/imgs/ft" "$dataset_dir/datasets/imgs/judd" "$dataset_dir/datasets/imgs/pascal" \
+        "$dataset_dir/datasets/masks/bruce" "$dataset_dir/datasets/masks/ft" "$dataset_dir/datasets/masks/pascal"
+    mv "$dataset_dir/datasets/imgs/imgsal"/* "$dataset_dir/datasets/imgs"
+    mv "$dataset_dir/datasets/masks/imgsal"/* "$dataset_dir/datasets/masks"
+    rm -rf "$dataset_dir/datasets/imgs/imgsal" "$dataset_dir/datasets/imgs/Thumbs.db" "$dataset_dir/datasets/masks/imgsal"
+    mv "$dataset_dir/datasets/imgs" "$dataset_dir/datasets/images"
+    mv "$dataset_dir/datasets" "$dataset_dir/pascals_data"
+}
+
+usage() {
+    echo "Usage: $0 [-d] [-m] [-p]"
+    echo "  -d  Download DUTS dataset (train and test)"
+    echo "  -m  Download MSRA10K dataset"
+    echo "  -p  Download PASCAL-S dataset"
+    echo "If no options are provided, all datasets will be downloaded."
+    exit 1
+}
+
+all=false
+while getopts "dmp" opt; do
+    case $opt in
+        d)
+            download_duts
+            ;;
+        m)
+            download_msra
+            ;;
+        p)
+            download_pascals
+            ;;
+        *)
+            usage
+            ;;
+    esac
+done
+
+# Check if no options were provided
+if [ $OPTIND -eq 1 ]; then
+    echo "No options provided; downloading all datasets."
+    download_duts
+    download_msra
+    download_pascals
+fi
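Per the script's own usage() text: -d fetches DUTS (train and test), -m MSRA10K, -p PASCAL-S, and running it with no flags downloads all three; e.g. ./download_data.sh -d -m skips PASCAL-S. One apparent wrinkle: the script extracts everything under data_temp/, while data_loader.py reads from /data/...; presumably data_temp is moved or mounted to /data before training.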
u2net/evaluate.py ADDED
@@ -0,0 +1,41 @@
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+from data_loader import PASCALSDataset
+from model import U2Net
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print('Device:', device)
+
+def load_model(model, model_path):
+    state_dict = torch.load(model_path, map_location=device, weights_only=True)
+    model.load_state_dict(state_dict)
+    model.eval()
+
+def evaluate(model, loader, criterion):
+    model.eval()
+    running_loss = 0.
+    with torch.no_grad():
+        for images, masks in tqdm(loader, desc='Testing'):
+            images, masks = images.to(device), masks.to(device)
+            outputs = model(images)
+            loss = sum(criterion(output, masks) for output in outputs)
+            running_loss += loss.item()
+    return running_loss / len(loader)
+
+
+if __name__ == '__main__':
+    batch_size = 40
+
+    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
+    model = U2Net().to(device)
+    model = nn.DataParallel(model)
+    load_model(model, 'results/inter-u2net-duts.pt')
+
+    loader = DataLoader(PASCALSDataset(split='eval'), batch_size=batch_size, shuffle=False)
+
+    loss = evaluate(model, loader, loss_fn)
+    print(f'Loss: {loss:.4f}')
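Note that U2Net now returns raw logits (see the model.py change below), so the BCEWithLogitsLoss here applies the sigmoid internally. The printed figure is the deep-supervision BCE summed over all side outputs and averaged over batches, rather than a standard saliency metric such as MAE or F-measure.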
u2net/inference.py ADDED
@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+from torchvision import transforms
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+from matplotlib.gridspec import GridSpec
+
+from model import U2Net
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+def preprocess_image(image_path):
+    img = Image.open(image_path).convert('RGB')
+    preprocess = transforms.Compose([
+        transforms.Resize((512, 512)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    img = preprocess(img).unsqueeze(0).to(device)
+    return img
+
+def run_inference(model, image_path, threshold=0.5):
+    input_img = preprocess_image(image_path)
+    with torch.no_grad():
+        d1, *_ = model(input_img)  # first side output, raw logits
+        pred = torch.sigmoid(d1)
+        pred = pred[0, :, :].cpu().numpy()
+
+    pred = (pred - pred.min()) / (pred.max() - pred.min())
+    if threshold is not None:
+        pred = (pred > threshold).astype(np.uint8) * 255
+    else:
+        pred = (pred * 255).astype(np.uint8)
+    return pred
+
+def overlay_segmentation(original_image, binary_mask, alpha=0.5):
+    original_image = Image.open(original_image).convert('RGB').resize((512, 512), Image.BILINEAR)
+    original_image_np = np.array(original_image)
+    overlay = np.zeros_like(original_image_np)
+    overlay[:, :, 0] = binary_mask  # paint the mask into the red channel
+    overlay_image = (1 - alpha) * original_image_np + alpha * overlay
+    overlay_image = overlay_image.astype(np.uint8)
+    return overlay_image
+
+
+if __name__ == '__main__':
+    # ---
+    model_path = 'results/inter-u2net-duts.pt'
+    image_path = 'images/ladies.jpg'
+    # ---
+    model = U2Net().to(device)
+    model = nn.DataParallel(model)
+    state_dict = torch.load(model_path, map_location=device, weights_only=True)
+    model.load_state_dict(state_dict)
+    model.eval()
+
+    mask = run_inference(model, image_path, threshold=None)
+    mask_with_threshold = run_inference(model, image_path, threshold=0.7)
+
+    fig = plt.figure(figsize=(10, 10))
+    gs = GridSpec(2, 2, figure=fig, wspace=0, hspace=0)
+
+    images = [
+        Image.open(image_path).resize((512, 512)),
+        mask,
+        overlay_segmentation(image_path, mask_with_threshold),
+        mask_with_threshold
+    ]
+
+    for i, img in enumerate(images):
+        ax = fig.add_subplot(gs[i // 2, i % 2])
+        ax.imshow(img, cmap='gray' if i % 2 != 0 else None)
+        ax.axis('off')
+
+    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+    plt.savefig('inference-output.jpg', format='jpg', bbox_inches='tight', pad_inches=0)
u2net/model.py CHANGED
@@ -29,7 +29,7 @@ class RSU(nn.Module):
         self.conv = ConvBlock(C_in, C_out)
 
         self.enc = nn.ModuleList([ConvBlock(C_out, M)])
-        for i in range(L-2):
+        for _ in range(L-2):
             self.enc.append(ConvBlock(M, M))
 
         self.mid = ConvBlock(M, M, dilation=2)
@@ -148,4 +148,5 @@
 
         side_out.append(self.lastconv(torch.cat(side_out, dim=1)))
 
-        return [torch.sigmoid(s.squeeze(1)) for s in side_out]
+        # logits
+        return [s.squeeze(1) for s in side_out]
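With the sigmoid moved out of the model, the callers take on the logits-to-probability step: train.py and evaluate.py feed the raw logits to nn.BCEWithLogitsLoss (numerically stabler than sigmoid followed by BCELoss), while inference.py applies torch.sigmoid to the first returned map. A minimal sketch of the resulting contract (shapes assume 512x512 inputs and batch size B):

outputs = model(images)                            # list of logit maps, each (B, 512, 512)
loss = sum(criterion(o, masks) for o in outputs)   # criterion = nn.BCEWithLogitsLoss()
probs = torch.sigmoid(outputs[0])                  # the map inference.py uses as its prediction, in [0, 1]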
u2net/train.py CHANGED
@@ -1 +1,85 @@
-# for training u2net
+import pickle
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, ConcatDataset
+from torch.amp import autocast, GradScaler
+
+from data_loader import DUTSDataset, MSRADataset
+from model import U2Net
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+scaler = GradScaler()
+
+def train_one_epoch(model, loader, criterion, optimizer):
+    model.train()
+    running_loss = 0.
+    for images, masks in tqdm(loader, desc='Training', leave=False):
+        images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)
+
+        optimizer.zero_grad()
+        with autocast(device_type='cuda'):
+            outputs = model(images)
+            loss = sum(criterion(output, masks) for output in outputs)
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+
+        running_loss += loss.item()
+    return running_loss / len(loader)
+
+def validate(model, loader, criterion):
+    model.eval()
+    running_loss = 0.
+    with torch.no_grad():
+        for images, masks in tqdm(loader, desc='Validating', leave=False):
+            images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)
+            outputs = model(images)
+            loss = sum(criterion(output, masks) for output in outputs)
+            running_loss += loss.item()
+    avg_loss = running_loss / len(loader)
+    return avg_loss
+
+
+if __name__ == '__main__':
+    batch_size = 40
+    valid_batch_size = 80
+    epochs = 100
+
+    lr = 1e-4
+    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
+
+    model_name = 'u2net-duts'
+    model = U2Net()
+    model = torch.nn.DataParallel(model.to(device))
+    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
+
+    train_loader = DataLoader(
+        ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
+        batch_size=batch_size, shuffle=True, pin_memory=True,
+        num_workers=16, persistent_workers=True
+    )
+    valid_loader = DataLoader(
+        ConcatDataset([DUTSDataset(split='valid'), MSRADataset(split='valid')]),
+        batch_size=valid_batch_size, shuffle=False, pin_memory=True,
+        num_workers=16, persistent_workers=True
+    )
+
+    losses = {'train': [], 'val': []}
+    for epoch in tqdm(range(epochs), desc='Epochs'):
+        torch.cuda.empty_cache()
+        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
+        val_loss = validate(model, valid_loader, loss_fn)
+        losses['train'].append(train_loss)
+        losses['val'].append(val_loss)
+
+        if (epoch + 1) % 10 == 0:
+            torch.save(model.state_dict(), f'results/inter-{model_name}.pt')
+
+        print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
+
+    torch.save(model.state_dict(), f'results/{model_name}.pt')
+    with open('results/loss.txt', 'wb') as f:
+        pickle.dump(losses, f)
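To inspect the training curves afterwards, a small sketch along these lines should work; the keys and path match what train.py writes, while the output filename is illustrative:

import pickle
import matplotlib.pyplot as plt

# results/loss.txt is pickled binary, despite the .txt suffix
with open('results/loss.txt', 'rb') as f:
    losses = pickle.load(f)

plt.plot(losses['train'], label='train')
plt.plot(losses['val'], label='val')
plt.xlabel('epoch')
plt.ylabel('summed BCE loss')
plt.legend()
plt.savefig('results/loss-curves.png')  # illustrative output path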