import os
import numpy as np
import torch
import random
from PIL import Image, ImageDraw, ImageFont
import pickle
from config.GlobalVariables import *
np.random.seed(0)


class DataLoader():
    """Batches per-writer handwriting stroke data at sentence, word, and
    segment level from .npy files stored under `datadir`."""

    def __init__(self, num_writer=2, num_samples=5, divider=10.0, datadir='./data/writers'):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.num_writer = num_writer    # writers per batch
        self.num_samples = num_samples  # texts sampled per writer
        self.divider = divider          # scale factor applied to x/y offsets
        self.datadir = datadir
        print('self.datadir : ', self.datadir)
        self.total_writers = len(os.listdir(datadir))
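
    # Expected on-disk layout (as used by the paths built in next_batch):
    #     <datadir>/<writer id>/<text id>.npy
    # Each .npy file stores sentence-, word-, and segment-level raw strokes,
    # stroke in/out offsets, termination flags, and character labels for one
    # handwritten text.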
    def next_batch(self, TYPE='TRAIN', uid=-1, tids=[]):
        all_sentence_level_stroke_in = []
        all_sentence_level_stroke_out = []
        all_sentence_level_stroke_length = []
        all_sentence_level_term = []
        all_sentence_level_char = []
        all_sentence_level_char_length = []
        all_word_level_stroke_in = []
        all_word_level_stroke_out = []
        all_word_level_stroke_length = []
        all_word_level_term = []
        all_word_level_char = []
        all_word_level_char_length = []
        all_segment_level_stroke_in = []
        all_segment_level_stroke_out = []
        all_segment_level_stroke_length = []
        all_segment_level_term = []
        all_segment_level_char = []
        all_segment_level_char_length = []

        # Keep drawing writers until the batch holds self.num_writer entries.
        # uid and tids keep their values across iterations, so every entry in
        # the batch comes from the same writer/texts unless new ones are given.
        while len(all_sentence_level_stroke_in) < self.num_writer:
            # Pick a writer id if none was supplied.
            if uid < 0:
                if TYPE == 'TRAIN':
                    if self.datadir == './data/NEW_writers' or self.datadir == './data/writers':
                        uid = np.random.choice(150)
                    else:
                        if self.device == 'cpu':
                            uid = np.random.choice(20)
                        else:
                            uid = np.random.choice(294)
                else:
                    uid = np.random.choice(np.arange(150, 170))

            # Sample which of this writer's texts to load.
            total_texts = len(os.listdir(self.datadir + '/' + str(uid)))
            if len(tids) == 0:
                tids = random.sample(range(total_texts), self.num_samples)

            # Per-writer accumulators; one entry per accepted sample.
            user_sentence_level_stroke_in = []
            user_sentence_level_stroke_out = []
            user_sentence_level_stroke_length = []
            user_sentence_level_term = []
            user_sentence_level_char = []
            user_sentence_level_char_length = []
            user_word_level_stroke_in = []
            user_word_level_stroke_out = []
            user_word_level_stroke_length = []
            user_word_level_term = []
            user_word_level_char = []
            user_word_level_char_length = []
            user_segment_level_stroke_in = []
            user_segment_level_stroke_out = []
            user_segment_level_stroke_length = []
            user_segment_level_term = []
            user_segment_level_char = []
            user_segment_level_char_length = []
            # print ("uid: ", uid, "\ttids:", tids)

            for tid in tids:
                # Each .npy file packs the per-level arrays in a dataset-specific order.
                if self.datadir == './data/NEW_writers':
                    [sentence_level_raw_stroke, sentence_level_stroke_in, sentence_level_stroke_out, sentence_level_term, sentence_level_char,
                     word_level_raw_stroke, word_level_stroke_in, word_level_stroke_out, word_level_term, word_level_char,
                     segment_level_raw_stroke, segment_level_stroke_in, segment_level_stroke_out, segment_level_term, segment_level_char] = \
                        np.load(self.datadir + '/' + str(uid) + '/' + str(tid) + '.npy', allow_pickle=True, encoding='bytes')
                elif self.datadir in ('./data/DW_writers', './data/VALID_DW_writers'):
                    [sentence_level_raw_stroke, sentence_level_char, sentence_level_term, sentence_level_stroke_in, sentence_level_stroke_out,
                     word_level_raw_stroke, word_level_char, word_level_term, word_level_stroke_in, word_level_stroke_out,
                     segment_level_raw_stroke, segment_level_char, segment_level_term, segment_level_stroke_in, segment_level_stroke_out, _] = \
                        np.load(self.datadir + '/' + str(uid) + '/' + str(tid) + '.npy', allow_pickle=True, encoding='bytes')
                else:
                    [sentence_level_raw_stroke, sentence_level_stroke_in, sentence_level_stroke_out, sentence_level_term, sentence_level_char,
                     word_level_raw_stroke, word_level_stroke_in, word_level_stroke_out, word_level_term, word_level_char,
                     segment_level_raw_stroke, segment_level_stroke_in, segment_level_stroke_out, segment_level_term, segment_level_char, _] = \
                        np.load(self.datadir + '/' + str(uid) + '/' + str(tid) + '.npy', allow_pickle=True, encoding='bytes')

                # DW_writers records carry an extra leading entry (presumably a
                # start token) in the char/term sequences; drop it.
                if self.datadir in ('./data/DW_writers', './data/VALID_DW_writers'):
                    sentence_level_char = sentence_level_char[1:]
                    sentence_level_term = sentence_level_term[1:]

                # Trim trailing points until the sequence ends on a termination
                # flag, keeping all per-point arrays in sync.
                while len(sentence_level_term) > 0 and sentence_level_term[-1] != 1.0:
                    sentence_level_raw_stroke = sentence_level_raw_stroke[:-1]
                    sentence_level_char = sentence_level_char[:-1]
                    sentence_level_term = sentence_level_term[:-1]
                    sentence_level_stroke_in = sentence_level_stroke_in[:-1]
                    sentence_level_stroke_out = sentence_level_stroke_out[:-1]

                # Characters at points whose termination flag is set form the
                # sentence's character sequence.
                tmp = []
                for i, t in enumerate(sentence_level_term):
                    if t == 1:
                        tmp.append(sentence_level_char[i])

                # Scale the x/y offsets (first two columns) by 1 / divider.
                a = np.ones_like(sentence_level_stroke_in)
                a[:, :2] /= self.divider

                if len(sentence_level_stroke_in) == len(sentence_level_term) and len(tmp) > 0 and len(sentence_level_stroke_in) > 0:
                    user_sentence_level_stroke_in.append(np.asarray(sentence_level_stroke_in) * a)
                    user_sentence_level_stroke_out.append(np.asarray(sentence_level_stroke_out) * a)
                    user_sentence_level_stroke_length.append(len(sentence_level_stroke_in))
                    user_sentence_level_char.append(np.asarray(tmp))
                    user_sentence_level_term.append(np.asarray(sentence_level_term))
                    user_sentence_level_char_length.append(len(tmp))

                    # Word level: apply the same trimming and filtering per word.
                    for wid in range(len(word_level_stroke_in)):
                        each_word_level_stroke_in = word_level_stroke_in[wid]
                        each_word_level_stroke_out = word_level_stroke_out[wid]
                        if self.datadir in ('./data/DW_writers', './data/VALID_DW_writers'):
                            each_word_level_term = word_level_term[wid][1:]
                            each_word_level_char = word_level_char[wid][1:]
                        else:
                            each_word_level_term = word_level_term[wid]
                            each_word_level_char = word_level_char[wid]
                        # assert (len(each_word_level_stroke_in) == len(each_word_level_char) == len(each_word_level_term))

                        while len(each_word_level_term) > 0 and each_word_level_term[-1] != 1.0:
                            # each_word_level_raw_stroke = each_word_level_raw_stroke[:-1]
                            each_word_level_char = each_word_level_char[:-1]
                            each_word_level_term = each_word_level_term[:-1]
                            each_word_level_stroke_in = each_word_level_stroke_in[:-1]
                            each_word_level_stroke_out = each_word_level_stroke_out[:-1]

                        tmp = []
                        for i, t in enumerate(each_word_level_term):
                            if t == 1:
                                tmp.append(each_word_level_char[i])

                        b = np.ones_like(each_word_level_stroke_in)
                        b[:, :2] /= self.divider

                        if len(each_word_level_stroke_in) == len(each_word_level_term) and len(tmp) > 0 and len(each_word_level_stroke_in) > 0:
                            user_word_level_stroke_in.append(np.asarray(each_word_level_stroke_in) * b)
                            user_word_level_stroke_out.append(np.asarray(each_word_level_stroke_out) * b)
                            user_word_level_stroke_length.append(len(each_word_level_stroke_in))
                            user_word_level_char.append(np.asarray(tmp))
                            user_word_level_term.append(np.asarray(each_word_level_term))
                            user_word_level_char_length.append(len(tmp))

                            # Segment level: each word is further split into stroke
                            # segments, which are processed and padded per word.
                            segment_level_stroke_in_list = []
                            segment_level_stroke_out_list = []
                            segment_level_stroke_length_list = []
                            segment_level_char_list = []
                            segment_level_term_list = []
                            segment_level_char_length_list = []
                            for sid in range(len(segment_level_stroke_in[wid])):
                                each_segment_level_stroke_in = segment_level_stroke_in[wid][sid]
                                each_segment_level_stroke_out = segment_level_stroke_out[wid][sid]
                                if self.datadir in ('./data/DW_writers', './data/VALID_DW_writers'):
                                    each_segment_level_term = segment_level_term[wid][sid][1:]
                                    each_segment_level_char = segment_level_char[wid][sid][1:]
                                else:
                                    each_segment_level_term = segment_level_term[wid][sid]
                                    each_segment_level_char = segment_level_char[wid][sid]

                                while len(each_segment_level_term) > 0 and each_segment_level_term[-1] != 1.0:
                                    # each_segment_level_raw_stroke = each_segment_level_raw_stroke[:-1]
                                    each_segment_level_char = each_segment_level_char[:-1]
                                    each_segment_level_term = each_segment_level_term[:-1]
                                    each_segment_level_stroke_in = each_segment_level_stroke_in[:-1]
                                    each_segment_level_stroke_out = each_segment_level_stroke_out[:-1]

                                tmp = []
                                for i, t in enumerate(each_segment_level_term):
                                    if t == 1:
                                        tmp.append(each_segment_level_char[i])

                                c = np.ones_like(each_segment_level_stroke_in)
                                c[:, :2] /= self.divider

                                if len(each_segment_level_stroke_in) == len(each_segment_level_term) and len(tmp) > 0 and len(each_segment_level_stroke_in) > 0:
                                    segment_level_stroke_in_list.append(np.asarray(each_segment_level_stroke_in) * c)
                                    segment_level_stroke_out_list.append(np.asarray(each_segment_level_stroke_out) * c)
                                    segment_level_stroke_length_list.append(len(each_segment_level_stroke_in))
                                    segment_level_char_list.append(np.asarray(tmp))
                                    segment_level_term_list.append(np.asarray(each_segment_level_term))
                                    segment_level_char_length_list.append(len(tmp))

                            if len(segment_level_stroke_length_list) > 0:
                                # Pad every segment of this word to the longest segment
                                # so they stack into a single array.
                                SEGMENT_MAX_STROKE_LENGTH = np.max(segment_level_stroke_length_list)
                                SEGMENT_MAX_CHARACTER_LENGTH = np.max(segment_level_char_length_list)
                                new_segment_level_stroke_in_list = np.asarray([np.pad(a, ((0, SEGMENT_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in segment_level_stroke_in_list])
                                new_segment_level_stroke_out_list = np.asarray([np.pad(a, ((0, SEGMENT_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in segment_level_stroke_out_list])
                                new_segment_level_term_list = np.asarray([np.pad(a, (0, SEGMENT_MAX_STROKE_LENGTH - len(a)), 'constant') for a in segment_level_term_list])
                                new_segment_level_char_list = np.asarray([np.pad(a, (0, SEGMENT_MAX_CHARACTER_LENGTH - len(a)), 'constant') for a in segment_level_char_list])
                                user_segment_level_stroke_in.append(new_segment_level_stroke_in_list)
                                user_segment_level_stroke_out.append(new_segment_level_stroke_out_list)
                                user_segment_level_stroke_length.append(segment_level_stroke_length_list)
                                user_segment_level_char.append(new_segment_level_char_list)
                                user_segment_level_term.append(new_segment_level_term_list)
                                user_segment_level_char_length.append(segment_level_char_length_list)

            # Pad this writer's sentence- and word-level samples to a common
            # length so they stack into batch arrays.
            WORD_MAX_STROKE_LENGTH = np.max(user_word_level_stroke_length)
            WORD_MAX_CHARACTER_LENGTH = np.max(user_word_level_char_length)
            SENTENCE_MAX_STROKE_LENGTH = np.max(user_sentence_level_stroke_length)
            SENTENCE_MAX_CHARACTER_LENGTH = np.max(user_sentence_level_char_length)
            new_sentence_level_stroke_in = np.asarray([np.pad(a, ((0, SENTENCE_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in user_sentence_level_stroke_in])
            new_sentence_level_stroke_out = np.asarray([np.pad(a, ((0, SENTENCE_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in user_sentence_level_stroke_out])
            new_sentence_level_term = np.asarray([np.pad(a, (0, SENTENCE_MAX_STROKE_LENGTH - len(a)), 'constant') for a in user_sentence_level_term])
            new_sentence_level_char = np.asarray([np.pad(a, (0, SENTENCE_MAX_CHARACTER_LENGTH - len(a)), 'constant') for a in user_sentence_level_char])
            new_word_level_stroke_in = np.asarray([np.pad(a, ((0, WORD_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in user_word_level_stroke_in])
            new_word_level_stroke_out = np.asarray([np.pad(a, ((0, WORD_MAX_STROKE_LENGTH - len(a)), (0, 0)), 'constant') for a in user_word_level_stroke_out])
            new_word_level_term = np.asarray([np.pad(a, (0, WORD_MAX_STROKE_LENGTH - len(a)), 'constant') for a in user_word_level_term])
            new_word_level_char = np.asarray([np.pad(a, (0, WORD_MAX_CHARACTER_LENGTH - len(a)), 'constant') for a in user_word_level_char])

            # One entry per writer in the batch.
            all_sentence_level_stroke_in.append(new_sentence_level_stroke_in)
            all_sentence_level_stroke_out.append(new_sentence_level_stroke_out)
            all_sentence_level_stroke_length.append(user_sentence_level_stroke_length)
            all_sentence_level_term.append(new_sentence_level_term)
            all_sentence_level_char.append(new_sentence_level_char)
            all_sentence_level_char_length.append(user_sentence_level_char_length)
            all_word_level_stroke_in.append(new_word_level_stroke_in)
            all_word_level_stroke_out.append(new_word_level_stroke_out)
            all_word_level_stroke_length.append(user_word_level_stroke_length)
            all_word_level_term.append(new_word_level_term)
            all_word_level_char.append(new_word_level_char)
            all_word_level_char_length.append(user_word_level_char_length)
            all_segment_level_stroke_in.append(user_segment_level_stroke_in)
            all_segment_level_stroke_out.append(user_segment_level_stroke_out)
            all_segment_level_stroke_length.append(user_segment_level_stroke_length)
            all_segment_level_term.append(user_segment_level_term)
            all_segment_level_char.append(user_segment_level_char)
            all_segment_level_char_length.append(user_segment_level_char_length)

        # 18 parallel lists, one entry per writer: sentence-, word-, and
        # segment-level stroke in/out, lengths, termination flags, characters,
        # and character-sequence lengths.
        return [all_sentence_level_stroke_in, all_sentence_level_stroke_out, all_sentence_level_stroke_length,
                all_sentence_level_term, all_sentence_level_char, all_sentence_level_char_length,
                all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length,
                all_word_level_term, all_word_level_char, all_word_level_char_length,
                all_segment_level_stroke_in, all_segment_level_stroke_out, all_segment_level_stroke_length,
                all_segment_level_term, all_segment_level_char, all_segment_level_char_length]
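

if __name__ == '__main__':
    # Minimal usage sketch, assuming the default ./data/writers layout
    # described above is present locally; adjust datadir/num_samples to
    # match your data.
    loader = DataLoader(num_writer=2, num_samples=5, divider=10.0, datadir='./data/writers')
    batch = loader.next_batch(TYPE='TRAIN')
    all_sentence_stroke_in = batch[0]      # list with one padded array per writer
    all_sentence_stroke_length = batch[2]  # true (unpadded) lengths per writer
    print('writers in batch       :', len(all_sentence_stroke_in))
    print('padded sentence tensor :', all_sentence_stroke_in[0].shape)
    print('true sentence lengths  :', all_sentence_stroke_length[0])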