Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,007 Bytes
9669aec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import os.path as osp
import os
from PIL import Image
import numpy as np
import json
import cv2
from transform import *
class FaceMask(Dataset):
    """Dataset pairing face images with their per-pixel label masks.

    Images are read from ``<rootpth>/CelebA-HQ-img`` and the matching mask
    from ``<rootpth>/mask/<image stem>.png``.

    Args:
        rootpth: Directory containing the ``CelebA-HQ-img`` and ``mask``
            sub-directories.
        cropsize: Size handed to ``RandomCrop`` in the training pipeline.
        mode: One of ``'train'``, ``'val'`` or ``'test'``. The augmentation
            pipeline is only applied when ``mode == 'train'``.
        *args, **kwargs: Forwarded unchanged to ``Dataset.__init__``.
    """

    def __init__(self, rootpth, cropsize=(640, 480), mode='train', *args, **kwargs):
        super(FaceMask, self).__init__(*args, **kwargs)
        assert mode in ('train', 'val', 'test')
        self.mode = mode
        # Stored for consumers of the dataset; not read anywhere in this class.
        self.ignore_lb = 255
        self.rootpth = rootpth
        # os.listdir() returns entries in arbitrary order; sorting makes an
        # index refer to the same sample on every run and every machine.
        self.imgs = sorted(os.listdir(osp.join(self.rootpth, 'CelebA-HQ-img')))

        # pre-processing: PIL image -> float tensor, then per-channel
        # normalisation (the values match the commonly used ImageNet mean/std).
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        # Training-time augmentation applied to the {'im', 'lb'} pair.
        # NOTE(review): Compose/ColorJitter/HorizontalFlip/RandomScale/
        # RandomCrop presumably come from `from transform import *` - confirm.
        self.trans_train = Compose([
            ColorJitter(
                brightness=0.5,
                contrast=0.5,
                saturation=0.5),
            HorizontalFlip(),
            RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
            RandomCrop(cropsize),
        ])

    def _label_path(self, impth):
        """Return the mask path that belongs to image file name ``impth``.

        The extension is removed with ``splitext`` rather than by slicing a
        fixed number of characters, so names such as ``x.jpeg`` also map to
        ``x.png``.
        """
        stem = osp.splitext(impth)[0]
        return osp.join(self.rootpth, 'mask', stem + '.png')

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, label)``.

        ``img`` is the output of ``self.to_tensor``; ``label`` is an int64
        NumPy array with a new leading axis added in front of the mask.
        """
        impth = self.imgs[idx]
        # Force three channels: Normalize above is configured with
        # three-channel statistics and would fail on grayscale/RGBA input.
        img = Image.open(osp.join(self.rootpth, 'CelebA-HQ-img', impth)).convert('RGB')
        img = img.resize((512, 512), Image.BILINEAR)
        label = Image.open(self._label_path(impth)).convert('P')
        if self.mode == 'train':
            # Image and mask go through the pipeline together as one dict.
            im_lb = self.trans_train(dict(im=img, lb=label))
            img, label = im_lb['im'], im_lb['lb']
        img = self.to_tensor(img)
        label = np.array(label).astype(np.int64)[np.newaxis, :]
        return img, label

    def __len__(self):
        """Return the number of image files found under ``CelebA-HQ-img``."""
        return len(self.imgs)
if __name__ == "__main__":
    # Merge the per-attribute mask PNGs of every image into one label map
    # per image, where each pixel holds the 1-based index of the attribute
    # in `atts` (0 = no attribute). Later attributes overwrite earlier ones.
    face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'  # not read by this script
    face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
    mask_path = '/home/zll/data/CelebAMask-HQ/mask'

    # Loop-invariant, so built once instead of on every outer iteration.
    atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
            'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']

    # cv2.imwrite() does not create directories and only returns False when
    # the target folder is missing, so make sure it exists up front.
    if not osp.isdir(mask_path):
        os.makedirs(mask_path)

    counter = 0  # attribute files that were found on disk
    total = 0    # attribute files that were looked for
    # The annotation folder is split into 15 sub-folders of 2000 images each.
    for i in range(15):
        for j in range(i * 2000, (i + 1) * 2000):
            # Label ids are 0..18, so uint8 is sufficient and is a depth the
            # PNG writer stores without conversion.
            mask = np.zeros((512, 512), dtype=np.uint8)
            for label_id, att in enumerate(atts, 1):
                total += 1
                file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
                path = osp.join(face_sep_mask, str(i), file_name)
                if not os.path.exists(path):
                    continue
                counter += 1
                sep_mask = np.array(Image.open(path).convert('P'))
                # NOTE(review): 225 is presumably the palette index that pure
                # white receives after convert('P') - confirm against the
                # annotation files before changing it.
                mask[sep_mask == 225] = label_id
            out_path = '{}/{}.png'.format(mask_path, j)
            # Fail loudly instead of silently producing an incomplete dataset.
            if not cv2.imwrite(out_path, mask):
                raise IOError('failed to write {}'.format(out_path))
            print(j)
    print(counter, total)
|