|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import os |
|
from paddle.io import Dataset |
|
from .imaug import transform, create_operators |
|
import random |
|
|
|
|
|
class PGDataSet(Dataset):
    """Dataset backed by delimiter-separated label files.

    Every line of a label file holds an image path (joined onto ``data_dir``)
    and a label string, separated by ``delimiter`` (a tab by default). Lines
    are kept as raw bytes and only decoded when a sample is requested.
    """

    def __init__(self, config, mode, logger, seed=None):
        """Read the label files named in ``config`` and build the transforms.

        Args:
            config: Nested mapping with a ``'Global'`` entry and a ``mode``
                entry that contains ``'dataset'`` and ``'loader'`` mappings.
            mode: Key selecting the section of ``config`` to use. Compared
                case-insensitively against ``"train"`` / ``"eval"`` except
                where noted in ``get_image_info_list``.
            logger: Object providing ``info`` and ``error`` methods.
            seed: Value handed to ``random.seed`` before sampling/shuffling.

        Raises:
            AssertionError: If ``ratio_list`` is a sequence whose length
                differs from the number of label files.
        """
        super(PGDataSet, self).__init__()

        self.logger = logger
        self.seed = seed
        self.mode = mode
        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']

        self.delimiter = dataset_config.get('delimiter', '\t')
        # NOTE: pop() removes the key from the caller's config (kept as-is).
        label_file_list = dataset_config.pop('label_file_list')
        # A single path may be given as a plain string; wrap it so that
        # len() counts files rather than characters.
        if isinstance(label_file_list, str):
            label_file_list = [label_file_list]
        data_source_num = len(label_file_list)

        # A scalar ratio (including the default) applies to every label file.
        ratio_list = dataset_config.get("ratio_list", 1.0)
        if isinstance(ratio_list, (float, int)):
            ratio_list = [float(ratio_list)] * int(data_source_num)
        assert len(
            ratio_list
        ) == data_source_num, "The length of ratio_list should be the same as the file_list."
        self.data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        # Wrapped in a tuple so a tuple-valued list is formatted as one value.
        logger.info("Initialize indexs of datasets:%s" % (label_file_list,))
        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
        self.data_idx_order_list = list(range(len(self.data_lines)))
        if mode.lower() == "train":
            self.shuffle_data_random()

        self.ops = create_operators(dataset_config['transforms'], global_config)

        # True when at least one label file is subsampled (ratio below 1).
        # Not read inside this class.
        self.need_reset = any(x < 1 for x in ratio_list)

    def shuffle_data_random(self):
        """Shuffle ``self.data_lines`` in place when shuffling is enabled.

        Reseeds the global ``random`` generator with ``self.seed`` first.
        """
        if self.do_shuffle:
            random.seed(self.seed)
            random.shuffle(self.data_lines)

    def get_image_info_list(self, file_list, ratio_list):
        """Collect the raw (bytes) lines of every label file.

        Args:
            file_list: A label-file path or a list of such paths.
            ratio_list: Per-file fraction of lines to keep, indexed in step
                with ``file_list``.

        Returns:
            A flat list of ``bytes`` lines from all files, in file order.
        """
        if isinstance(file_list, str):
            file_list = [file_list]
        data_lines = []
        for idx, label_path in enumerate(file_list):
            with open(label_path, "rb") as f:
                lines = f.readlines()
                # NOTE(review): this mode check is case-sensitive, unlike the
                # ``.lower()`` checks elsewhere in the class; left unchanged
                # because altering it would change which lines get resampled.
                if self.mode == "train" or ratio_list[idx] < 1.0:
                    random.seed(self.seed)
                    lines = random.sample(lines,
                                          round(len(lines) * ratio_list[idx]))
                data_lines.extend(lines)
        return data_lines

    def __getitem__(self, idx):
        """Load, decode and transform the sample at position ``idx``.

        If the line cannot be parsed, the image is missing, or the transform
        pipeline returns ``None``, the error is logged and a randomly chosen
        sample is returned instead.

        Args:
            idx: Position in ``self.data_idx_order_list``.

        Returns:
            Whatever ``transform`` produces for the selected sample.
        """
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        img_id = 0
        try:
            data_line = data_line.decode('utf-8')
            substr = data_line.strip("\n").split(self.delimiter)
            file_name = substr[0]
            label = substr[1]
            img_path = os.path.join(self.data_dir, file_name)
            if self.mode.lower() == 'eval':
                # The id is the integer found from character 7 up to the
                # first "." of the line; presumably this matches a fixed
                # file-name prefix -- TODO confirm against the label files.
                try:
                    img_id = int(data_line.split(".")[0][7:])
                except ValueError:
                    # Not a numeric id: fall back to 0.
                    img_id = 0
            data = {'img_path': img_path, 'label': label, 'img_id': img_id}
            if not os.path.exists(img_path):
                raise Exception("{} does not exist!".format(img_path))
            with open(data['img_path'], 'rb') as f:
                img = f.read()
                data['image'] = img
            outs = transform(data, self.ops)
        except Exception as e:
            self.logger.error(
                "When parsing line {}, error happened with msg: {}".format(
                    file_idx, e))
            outs = None
        if outs is None:
            # Retry with a random index. The recursion is unbounded, so a
            # dataset where every sample fails ends in RecursionError.
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        """Return the number of loaded label lines."""
        return len(self.data_idx_order_list)
|
|