Danieldu
add code
a89d9fd
raw
history blame
No virus
7.75 kB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
from paddle.io import Dataset
import lmdb
import cv2
import string
import six
from PIL import Image
from .imaug import transform, create_operators
class LMDBDataSet(Dataset):
    """Dataset that serves image/label samples stored in LMDB databases.

    Every leaf directory found below ``data_dir`` is opened as its own LMDB
    environment.  Samples are addressed through ``data_idx_order_list``, an
    array with one row per sample holding ``(lmdb index, record index)``,
    where the record index is 1-based.
    """

    def __init__(self, config, mode, logger, seed=None):
        """Open the databases, build the sample index and the transforms.

        Args:
            config: Configuration mapping.  Reads ``config['Global']``,
                ``config[mode]['dataset']`` and ``config[mode]['loader']``.
            mode: Key selecting the configuration section to use.
            logger: Object whose ``info`` method receives progress messages.
            seed: Accepted for interface compatibility; not used here.
        """
        super(LMDBDataSet, self).__init__()
        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        data_dir = dataset_config['data_dir']

        self.do_shuffle = loader_config['shuffle']
        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
        logger.info("Initialize indexs of datasets:%s" % data_dir)
        self.data_idx_order_list = self.dataset_traversal()
        if self.do_shuffle:
            # Shuffles the rows of the index in place.
            np.random.shuffle(self.data_idx_order_list)

        self.ops = create_operators(dataset_config['transforms'], global_config)
        # Only the first ``ext_op_transform_idx`` operators are applied to the
        # extra samples drawn by ``get_ext_data``.
        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
                                                       1)

        ratio_list = dataset_config.get("ratio_list", [1.0])
        self.need_reset = any(x < 1 for x in ratio_list)

    def load_hierarchical_lmdb_dataset(self, data_dir):
        """Open every leaf directory under ``data_dir`` as an LMDB database.

        Args:
            data_dir: Root directory to walk.

        Returns:
            dict mapping a running integer index to a dict with the keys
            ``dirpath``, ``env``, ``txn`` and ``num_samples``.
        """
        lmdb_sets = {}
        dataset_idx = 0
        for dirpath, dirnames, _ in os.walk(data_dir + '/'):
            if dirnames:
                # Only directories without sub-directories are opened.
                continue
            env = lmdb.open(
                dirpath,
                max_readers=32,
                readonly=True,
                lock=False,
                readahead=False,
                meminit=False)
            txn = env.begin(write=False)
            num_samples = int(txn.get(b'num-samples'))
            lmdb_sets[dataset_idx] = {
                "dirpath": dirpath,
                "env": env,
                "txn": txn,
                "num_samples": num_samples,
            }
            dataset_idx += 1
        return lmdb_sets

    def dataset_traversal(self):
        """Build the global sample index over all opened databases.

        Returns:
            Array of shape ``(total_samples, 2)``.  Column 0 is the key into
            ``self.lmdb_sets``; column 1 is the 1-based record index inside
            that database.
        """
        sample_counts = [
            self.lmdb_sets[lno]['num_samples']
            for lno in range(len(self.lmdb_sets))
        ]
        data_idx_order_list = np.zeros((sum(sample_counts), 2))
        beg_idx = 0
        for lno, count in enumerate(sample_counts):
            end_idx = beg_idx + count
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            # Record indices start at 1, not 0.
            data_idx_order_list[beg_idx:end_idx, 1] = np.arange(1, count + 1)
            beg_idx = end_idx
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode an encoded image buffer.

        Args:
            value: Encoded image bytes.

        Returns:
            The array produced by ``cv2.imdecode`` (flag ``1``), or ``None``
            when ``value`` is empty or decoding fails.
        """
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        # cv2.imdecode itself yields None when the buffer cannot be decoded.
        return cv2.imdecode(imgdata, 1)

    def get_ext_data(self):
        """Draw the extra random samples requested by the transform pipeline.

        The number of samples is the ``ext_data_num`` attribute of the first
        operator that defines one (0 if none does).  Each drawn sample is run
        through the first ``ext_op_transform_idx`` operators; samples that
        cannot be read or transformed are redrawn.

        Returns:
            list of transformed sample dicts.
        """
        ext_data_num = next(
            (op.ext_data_num for op in self.ops
             if hasattr(op, 'ext_data_num')), 0)
        load_data_ops = self.ops[:self.ext_op_transform_idx]

        ext_data = []
        while len(ext_data) < ext_data_num:
            lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
                len(self))]
            sample_info = self.get_lmdb_sample_info(
                self.lmdb_sets[int(lmdb_idx)]['txn'], int(file_idx))
            if sample_info is None:
                continue
            img, label = sample_info
            data = transform({'image': img, 'label': label}, load_data_ops)
            if data is None:
                continue
            ext_data.append(data)
        return ext_data

    def get_lmdb_sample_info(self, txn, index):
        """Fetch the raw image buffer and label of one record.

        Args:
            txn: LMDB read transaction.
            index: 1-based record index inside the database.

        Returns:
            ``(image_bytes, label_str)``, or ``None`` when either the label or
            the image record is missing so that callers can draw another
            sample.
        """
        label = txn.get(b'label-%09d' % index)
        if label is None:
            return None
        label = label.decode('utf-8')
        imgbuf = txn.get(b'image-%09d' % index)
        if imgbuf is None:
            # Treat a missing image like a missing label instead of handing
            # ``None`` to the transforms.
            return None
        return imgbuf, label

    def __getitem__(self, idx):
        """Return the transformed sample at position ``idx``.

        If the record cannot be read, or the transforms return ``None``, a
        randomly chosen other sample is returned instead.
        """
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[int(lmdb_idx)]['txn'], int(file_idx))
        if sample_info is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        img, label = sample_info
        data = {'image': img, 'label': label}
        data['ext_data'] = self.get_ext_data()
        outs = transform(data, self.ops)
        if outs is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        """Return the total number of indexed samples."""
        return self.data_idx_order_list.shape[0]
class LMDBDataSetSR(LMDBDataSet):
    """LMDB dataset variant serving (high-res image, low-res image, label).

    Records are read from the keys ``image_hr-%09d``, ``image_lr-%09d`` and
    ``label-%09d``.
    """

    # Alphabet name passed to ``str_filt`` when cleaning labels.
    voc_type = 'upper'
    # Labels longer than this many characters are skipped.
    max_len = 100
    # Not read anywhere in this class; kept because the attribute has always
    # been exposed on instances.
    test = False

    def buf2PIL(self, txn, key, type='RGB'):
        """Load the record stored under ``key`` as a PIL image.

        Args:
            txn: LMDB read transaction.
            key: Record key (bytes).
            type: PIL mode passed to ``Image.convert``.

        Returns:
            The decoded image converted to ``type``.

        Raises:
            IOError: If no record exists for ``key``.  ``Image.open`` may
                raise its own errors for buffers it cannot read.
        """
        imgbuf = txn.get(key)
        if imgbuf is None:
            # Surface a missing record as IOError so the caller's handler
            # can skip the sample.
            raise IOError('missing LMDB record: %r' % (key, ))
        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        return Image.open(buf).convert(type)

    def str_filt(self, str_, voc_type):
        """Drop every character of ``str_`` outside the chosen alphabet.

        Args:
            str_: Text to filter.
            voc_type: One of ``'digit'``, ``'lower'``, ``'upper'``, ``'all'``.
                With ``'lower'`` the text is lower-cased before filtering.

        Returns:
            The filtered string.

        Raises:
            KeyError: If ``voc_type`` is not one of the names above.
        """
        alpha_dict = {
            'digit': string.digits,
            'lower': string.digits + string.ascii_lowercase,
            'upper': string.digits + string.ascii_letters,
            'all': string.digits + string.ascii_letters + string.punctuation,
        }
        allowed = alpha_dict[voc_type]
        if voc_type == 'lower':
            str_ = str_.lower()
        return ''.join(char for char in str_ if char in allowed)

    def get_lmdb_sample_info(self, txn, index):
        """Fetch the image pair and the filtered label of one record.

        Args:
            txn: LMDB read transaction.
            index: Record index inside the database.

        Returns:
            ``(img_HR, img_lr, label_str)``, or ``None`` when the label is
            missing, the label is longer than ``max_len``, or either image
            cannot be loaded.  ``__getitem__`` reacts to ``None`` by drawing
            another sample.
        """
        label_buf = txn.get(b'label-%09d' % index)
        if label_buf is None:
            return None
        word = label_buf.decode()
        if len(word) > self.max_len:
            return None

        img_HR_key = b'image_hr-%09d' % index  # 128*32 per the original note
        img_lr_key = b'image_lr-%09d' % index  # 64*16 per the original note
        try:
            img_HR = self.buf2PIL(txn, img_HR_key, 'RGB')
            img_lr = self.buf2PIL(txn, img_lr_key, 'RGB')
        except IOError:
            return None

        label_str = self.str_filt(word, self.voc_type)
        return img_HR, img_lr, label_str

    def __getitem__(self, idx):
        """Return the transformed sample at position ``idx``.

        If the record cannot be read, or the transforms return ``None``, a
        randomly chosen other sample is returned instead.
        """
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[int(lmdb_idx)]['txn'], int(file_idx))
        if sample_info is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        img_HR, img_lr, label_str = sample_info
        data = {'image_hr': img_HR, 'image_lr': img_lr, 'label': label_str}
        outs = transform(data, self.ops)
        if outs is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs