import os
import random
import pickle

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont

from config.GlobalVariables import *

np.random.seed(0)


class DataLoader():
    def __init__(self, num_writer=2, num_samples=5, divider=10.0, datadir='./data/writers'):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.num_writer = num_writer
        self.num_samples = num_samples
        self.divider = divider
        self.datadir = datadir
        print('self.datadir : ', self.datadir)
        self.total_writers = len([name for name in os.listdir(datadir)])

    def next_batch(self, TYPE='TRAIN', uid=-1, tids=[]):
        """Sample `num_writer` writers and return their sentence-, word- and
        segment-level stroke sequences, character labels and lengths, padded per writer."""
        all_sentence_level_stroke_in = []
        all_sentence_level_stroke_out = []
        all_sentence_level_stroke_length = []
        all_sentence_level_term = []
        all_sentence_level_char = []
        all_sentence_level_char_length = []
        all_word_level_stroke_in = []
        all_word_level_stroke_out = []
        all_word_level_stroke_length = []
        all_word_level_term = []
        all_word_level_char = []
        all_word_level_char_length = []
        all_segment_level_stroke_in = []
        all_segment_level_stroke_out = []
        all_segment_level_stroke_length = []
        all_segment_level_term = []
        all_segment_level_char = []
        all_segment_level_char_length = []

        while len(all_sentence_level_stroke_in) < self.num_writer:
            # Pick a writer id if none was given: training ids depend on the dataset,
            # validation ids are 150-169.
            if uid < 0:
                if TYPE == 'TRAIN':
                    if self.datadir == './data/NEW_writers' or self.datadir == './data/writers':
                        uid = np.random.choice([i for i in range(150)])
                    else:
                        if self.device == 'cpu':
                            uid = np.random.choice([i for i in range(20)])
                        else:
                            uid = np.random.choice([i for i in range(294)])
                else:
                    uid = np.random.choice([i for i in range(150, 170)])

            total_texts = len([name for name in os.listdir(self.datadir + '/' + str(uid))])
            if len(tids) == 0:
                tids = random.sample([i for i in range(total_texts)], self.num_samples)

            user_sentence_level_stroke_in = []
            user_sentence_level_stroke_out = []
            user_sentence_level_stroke_length = []
            user_sentence_level_term = []
            user_sentence_level_char = []
            user_sentence_level_char_length = []
            user_word_level_stroke_in = []
            user_word_level_stroke_out = []
            user_word_level_stroke_length = []
            user_word_level_term = []
            user_word_level_char = []
            user_word_level_char_length = []
            user_segment_level_stroke_in = []
            user_segment_level_stroke_out = []
            user_segment_level_stroke_length = []
            user_segment_level_term = []
            user_segment_level_char = []
            user_segment_level_char_length = []

            # print ("uid: ", uid, "\ttids:", tids)
            for tid in tids:
                # Each <uid>/<tid>.npy holds sentence-, word- and segment-level strokes
                # and labels; the field order differs between the datasets.
                if self.datadir == './data/NEW_writers':
                    [sentence_level_raw_stroke, sentence_level_stroke_in, sentence_level_stroke_out, sentence_level_term, sentence_level_char,
                     word_level_raw_stroke, word_level_stroke_in, word_level_stroke_out, word_level_term, word_level_char,
                     segment_level_raw_stroke, segment_level_stroke_in, segment_level_stroke_out, segment_level_term, segment_level_char] = \
                        np.load(self.datadir + '/' + str(uid) + '/' + str(tid) + '.npy', allow_pickle=True, encoding='bytes')
                elif self.datadir == './data/DW_writers':
                    [sentence_level_raw_stroke, sentence_level_char, sentence_level_term, sentence_level_stroke_in, sentence_level_stroke_out,
                     word_level_raw_stroke, word_level_char, word_level_term, word_level_stroke_in, word_level_stroke_out,
                     segment_level_raw_stroke, segment_level_char, segment_level_term, segment_level_stroke_in, segment_level_stroke_out, _] = \
                        np.load(self.datadir + '/' + str(uid) + '/' + str(tid) + '.npy', allow_pickle=True, encoding='bytes')
                elif self.datadir == './data/VALID_DW_writers':
                    [sentence_level_raw_stroke, sentence_level_char, sentence_level_term, sentence_level_stroke_in, sentence_level_stroke_out,
                     word_level_raw_stroke, word_level_char, word_level_term, word_level_stroke_in, word_level_stroke_out,
                     segment_level_raw_stroke, segment_level_char, segment_level_term, segment_level_stroke_in, segment_level_stroke_out, _] = \
                        np.load(self.datadir + '/' + str(uid) + '/' + str(tid) + '.npy', allow_pickle=True, encoding='bytes')
                else:
                    [sentence_level_raw_stroke, sentence_level_stroke_in, sentence_level_stroke_out, sentence_level_term, sentence_level_char,
                     word_level_raw_stroke, word_level_stroke_in, word_level_stroke_out, word_level_term, word_level_char,
                     segment_level_raw_stroke, segment_level_stroke_in, segment_level_stroke_out, segment_level_term, segment_level_char, _] = \
                        np.load(self.datadir + '/' + str(uid) + '/' + str(tid) + '.npy', allow_pickle=True, encoding='bytes')

                # Drop the leading char/term entry for the DW-style datasets.
                if self.datadir == './data/DW_writers':
                    sentence_level_char = sentence_level_char[1:]
                    sentence_level_term = sentence_level_term[1:]
                if self.datadir == './data/VALID_DW_writers':
                    sentence_level_char = sentence_level_char[1:]
                    sentence_level_term = sentence_level_term[1:]

                # Trim trailing points until the sentence ends on a terminal flag (term == 1.0).
                while True:
                    if len(sentence_level_term) == 0:
                        break
                    if sentence_level_term[-1] != 1.0:
                        sentence_level_raw_stroke = sentence_level_raw_stroke[:-1]
                        sentence_level_char = sentence_level_char[:-1]
                        sentence_level_term = sentence_level_term[:-1]
                        sentence_level_stroke_in = sentence_level_stroke_in[:-1]
                        sentence_level_stroke_out = sentence_level_stroke_out[:-1]
                    else:
                        break

                # Collect the character emitted at each terminal point.
                tmp = []
                for i, t in enumerate(sentence_level_term):
                    if t == 1:
                        tmp.append(sentence_level_char[i])

                # Scale the (x, y) offsets by self.divider; the other columns stay unchanged.
                a = np.ones_like(sentence_level_stroke_in)
                a[:, :2] /= self.divider

                if len(sentence_level_stroke_in) == len(sentence_level_term) and len(tmp) > 0 and len(sentence_level_stroke_in) > 0:
                    user_sentence_level_stroke_in.append(np.asarray(sentence_level_stroke_in) * a)
                    user_sentence_level_stroke_out.append(np.asarray(sentence_level_stroke_out) * a)
                    user_sentence_level_stroke_length.append(len(sentence_level_stroke_in))
                    user_sentence_level_char.append(np.asarray(tmp))
                    user_sentence_level_term.append(np.asarray(sentence_level_term))
                    user_sentence_level_char_length.append(len(tmp))

                for wid in range(len(word_level_stroke_in)):
                    each_word_level_stroke_in = word_level_stroke_in[wid]
                    each_word_level_stroke_out = word_level_stroke_out[wid]
                    if self.datadir == './data/DW_writers':
                        each_word_level_term = word_level_term[wid][1:]
                        each_word_level_char = word_level_char[wid][1:]
                    elif self.datadir == './data/VALID_DW_writers':
                        each_word_level_term = word_level_term[wid][1:]
                        each_word_level_char = word_level_char[wid][1:]
                    else:
                        each_word_level_term = word_level_term[wid]
                        each_word_level_char = word_level_char[wid]
                    # assert (len(each_word_level_stroke_in) == len(each_word_level_char) == len(each_word_level_term))

                    while True:
                        if len(each_word_level_term) == 0:
                            break
                        if each_word_level_term[-1] != 1.0:
                            # each_word_level_raw_stroke = each_word_level_raw_stroke[:-1]
                            each_word_level_char = each_word_level_char[:-1]
                            each_word_level_term = each_word_level_term[:-1]
                            each_word_level_stroke_in = each_word_level_stroke_in[:-1]
                            each_word_level_stroke_out = each_word_level_stroke_out[:-1]
                        else:
                            break

                    tmp = []
                    for i, t in enumerate(each_word_level_term):
                        if t == 1:
                            tmp.append(each_word_level_char[i])

                    b = np.ones_like(each_word_level_stroke_in)
                    b[:, :2] /= self.divider

                    if len(each_word_level_stroke_in) == len(each_word_level_term) and len(tmp) > 0 and len(each_word_level_stroke_in) > 0:
                        user_word_level_stroke_in.append(np.asarray(each_word_level_stroke_in) * b)
                        user_word_level_stroke_out.append(np.asarray(each_word_level_stroke_out) * b)
                        user_word_level_stroke_length.append(len(each_word_level_stroke_in))
                        user_word_level_char.append(np.asarray(tmp))
                        user_word_level_term.append(np.asarray(each_word_level_term))
                        user_word_level_char_length.append(len(tmp))

                        segment_level_stroke_in_list = []
                        segment_level_stroke_out_list = []
                        segment_level_stroke_length_list = []
                        segment_level_char_list = []
                        segment_level_term_list = []
                        segment_level_char_length_list = []
                        for sid in range(len(segment_level_stroke_in[wid])):
                            each_segment_level_stroke_in = segment_level_stroke_in[wid][sid]
                            each_segment_level_stroke_out = segment_level_stroke_out[wid][sid]
                            if self.datadir == './data/DW_writers':
                                each_segment_level_term = segment_level_term[wid][sid][1:]
                                each_segment_level_char = segment_level_char[wid][sid][1:]
                            elif self.datadir == './data/VALID_DW_writers':
                                each_segment_level_term = segment_level_term[wid][sid][1:]
                                each_segment_level_char = segment_level_char[wid][sid][1:]
                            else:
                                each_segment_level_term = segment_level_term[wid][sid]
                                each_segment_level_char = segment_level_char[wid][sid]

                            while True:
                                if len(each_segment_level_term) == 0:
                                    break
                                if each_segment_level_term[-1] != 1.0:
                                    # each_segment_level_raw_stroke = each_segment_level_raw_stroke[:-1]
                                    each_segment_level_char = each_segment_level_char[:-1]
                                    each_segment_level_term = each_segment_level_term[:-1]
                                    each_segment_level_stroke_in = each_segment_level_stroke_in[:-1]
                                    each_segment_level_stroke_out = each_segment_level_stroke_out[:-1]
                                else:
                                    break

                            tmp = []
                            for i, t in enumerate(each_segment_level_term):
                                if t == 1:
                                    tmp.append(each_segment_level_char[i])

                            c = np.ones_like(each_segment_level_stroke_in)
                            c[:, :2] /= self.divider

                            if len(each_segment_level_stroke_in) == len(each_segment_level_term) and len(tmp) > 0 and len(each_segment_level_stroke_in) > 0:
                                segment_level_stroke_in_list.append(np.asarray(each_segment_level_stroke_in) * c)
                                segment_level_stroke_out_list.append(np.asarray(each_segment_level_stroke_out) * c)
                                segment_level_stroke_length_list.append(len(each_segment_level_stroke_in))
                                segment_level_char_list.append(np.asarray(tmp))
                                segment_level_term_list.append(np.asarray(each_segment_level_term))
                                segment_level_char_length_list.append(len(tmp))

                        # Pad this word's segments to a common stroke/character length.
                        if len(segment_level_stroke_length_list) > 0:
                            SEGMENT_MAX_STROKE_LENGTH = np.max(segment_level_stroke_length_list)
                            SEGMENT_MAX_CHARACTER_LENGTH = np.max(segment_level_char_length_list)
                            new_segment_level_stroke_in_list = np.asarray([np.pad(a, ((0, SEGMENT_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in segment_level_stroke_in_list])
                            new_segment_level_stroke_out_list = np.asarray([np.pad(a, ((0, SEGMENT_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in segment_level_stroke_out_list])
                            new_segment_level_term_list = np.asarray([np.pad(a, ((0, SEGMENT_MAX_STROKE_LENGTH - len(a))), 'constant') for a in segment_level_term_list])
                            new_segment_level_char_list = np.asarray([np.pad(a, ((0, SEGMENT_MAX_CHARACTER_LENGTH - len(a))), 'constant') for a in segment_level_char_list])

                            user_segment_level_stroke_in.append(new_segment_level_stroke_in_list)
                            user_segment_level_stroke_out.append(new_segment_level_stroke_out_list)
                            user_segment_level_stroke_length.append(segment_level_stroke_length_list)
                            user_segment_level_char.append(new_segment_level_char_list)
                            user_segment_level_term.append(new_segment_level_term_list)
                            user_segment_level_char_length.append(segment_level_char_length_list)

            # Pad this writer's sentences and words to common lengths before stacking.
            WORD_MAX_STROKE_LENGTH = np.max(user_word_level_stroke_length)
            WORD_MAX_CHARACTER_LENGTH = np.max(user_word_level_char_length)
            SENTENCE_MAX_STROKE_LENGTH = np.max(user_sentence_level_stroke_length)
            SENTENCE_MAX_CHARACTER_LENGTH = np.max(user_sentence_level_char_length)

            new_sentence_level_stroke_in = np.asarray([np.pad(a, ((0, SENTENCE_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in user_sentence_level_stroke_in])
            new_sentence_level_stroke_out = np.asarray([np.pad(a, ((0, SENTENCE_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in user_sentence_level_stroke_out])
            new_sentence_level_term = np.asarray([np.pad(a, ((0, SENTENCE_MAX_STROKE_LENGTH - len(a))), 'constant') for a in user_sentence_level_term])
            new_sentence_level_char = np.asarray([np.pad(a, ((0, SENTENCE_MAX_CHARACTER_LENGTH - len(a))), 'constant') for a in user_sentence_level_char])
            new_word_level_stroke_in = np.asarray([np.pad(a, ((0, WORD_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in user_word_level_stroke_in])
            new_word_level_stroke_out = np.asarray([np.pad(a, ((0, WORD_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in user_word_level_stroke_out])
            new_word_level_term = np.asarray([np.pad(a, ((0, WORD_MAX_STROKE_LENGTH - len(a))), 'constant') for a in user_word_level_term])
            new_word_level_char = np.asarray([np.pad(a, ((0, WORD_MAX_CHARACTER_LENGTH - len(a))), 'constant') for a in user_word_level_char])

            all_sentence_level_stroke_in.append(new_sentence_level_stroke_in)
            all_sentence_level_stroke_out.append(new_sentence_level_stroke_out)
            all_sentence_level_stroke_length.append(user_sentence_level_stroke_length)
            all_sentence_level_term.append(new_sentence_level_term)
            all_sentence_level_char.append(new_sentence_level_char)
            all_sentence_level_char_length.append(user_sentence_level_char_length)
            all_word_level_stroke_in.append(new_word_level_stroke_in)
            all_word_level_stroke_out.append(new_word_level_stroke_out)
            all_word_level_stroke_length.append(user_word_level_stroke_length)
            all_word_level_term.append(new_word_level_term)
            all_word_level_char.append(new_word_level_char)
            all_word_level_char_length.append(user_word_level_char_length)
            all_segment_level_stroke_in.append(user_segment_level_stroke_in)
            all_segment_level_stroke_out.append(user_segment_level_stroke_out)
            all_segment_level_stroke_length.append(user_segment_level_stroke_length)
            all_segment_level_term.append(user_segment_level_term)
            all_segment_level_char.append(user_segment_level_char)
            all_segment_level_char_length.append(user_segment_level_char_length)

        return [all_sentence_level_stroke_in, all_sentence_level_stroke_out, all_sentence_level_stroke_length,
                all_sentence_level_term, all_sentence_level_char, all_sentence_level_char_length,
                all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length,
                all_word_level_term, all_word_level_char, all_word_level_char_length,
                all_segment_level_stroke_in, all_segment_level_stroke_out, all_segment_level_stroke_length,
                all_segment_level_term, all_segment_level_char, all_segment_level_char_length]
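

# ----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original loader): shows how the
# 18-element list returned by next_batch() unpacks. It assumes the default
# dataset layout under './data/writers' (one folder per writer, one .npy file
# per text) is already in place; adjust the paths and sizes to your setup.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    loader = DataLoader(num_writer=2, num_samples=5, divider=10.0, datadir='./data/writers')
    batch = loader.next_batch(TYPE='TRAIN')

    # One entry per sampled writer in each list, grouped as sentence-, word- and
    # segment-level (stroke_in, stroke_out, stroke_length, term, char, char_length).
    [sentence_stroke_in, sentence_stroke_out, sentence_stroke_len,
     sentence_term, sentence_char, sentence_char_len,
     word_stroke_in, word_stroke_out, word_stroke_len,
     word_term, word_char, word_char_len,
     segment_stroke_in, segment_stroke_out, segment_stroke_len,
     segment_term, segment_char, segment_char_len] = batch

    print('writers in batch          :', len(sentence_stroke_in))
    print('sentence strokes (padded) :', sentence_stroke_in[0].shape)
    print('word strokes (padded)     :', word_stroke_in[0].shape)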