# Handwriting_Model_Inf / prepare_data.py
# Uploaded by 3morrrrr ("Upload 14 files", commit 569596a, verified)
from __future__ import print_function
import os
from xml.etree import ElementTree
import numpy as np
import drawing
def get_stroke_sequence(filename):
    """Parse a line-stroke XML file into a normalized offset sequence.

    The first child of the XML root whose tag is ``StrokeSet`` is used. Each
    child element of each stroke in it becomes one row ``[x, -y, end_flag]``:
    ``x`` and ``y`` are read from the element's attributes as integers, ``y``
    is negated, and ``end_flag`` is 1 only for the last element of its
    stroke (0 otherwise).

    The rows then go through ``drawing.align`` and ``drawing.denoise``. They
    are converted with ``drawing.coords_to_offsets``, truncated to at most
    ``drawing.MAX_STROKE_LEN`` rows, and passed through
    ``drawing.normalize``, whose result is returned.
    """
    root = ElementTree.parse(filename).getroot()
    stroke_set = [child for child in root if child.tag == 'StrokeSet'][0]

    rows = [
        [int(pt.attrib['x']), -1 * int(pt.attrib['y']), int(idx == len(stroke) - 1)]
        for stroke in stroke_set
        for idx, pt in enumerate(stroke)
    ]

    points = drawing.denoise(drawing.align(np.array(rows)))
    offsets = drawing.coords_to_offsets(points)[:drawing.MAX_STROKE_LEN]
    return drawing.normalize(offsets)
def get_ascii_sequences(filename):
    """Read a transcription file and return its encoded, non-empty lines.

    Every ``%%%%%%%%%%%`` run in the file is turned into a line break, and
    every line is stripped. Lines start two positions after the first line
    equal to ``CSR:``; the line directly after the marker is skipped
    (presumably a separator — TODO confirm against the data). Blank lines are
    dropped. Each remaining line is encoded with ``drawing.encode_ascii`` and
    truncated to ``drawing.MAX_CHAR_LEN`` entries.

    Raises:
        ValueError: if the file contains no ``CSR:`` line.
    """
    # Use a context manager so the handle is closed even if parsing fails.
    with open(filename, 'r') as handle:
        text = handle.read()
    text = text.replace(r'%%%%%%%%%%%', '\n')
    stripped = [line.strip() for line in text.split('\n')]

    try:
        start = stripped.index('CSR:') + 2
    except ValueError:
        # Same exception type as before, but name the offending file.
        raise ValueError('no "CSR:" section in %s' % filename)

    # Lines are already stripped, so an empty string means a blank line.
    return [drawing.encode_ascii(line)[:drawing.MAX_CHAR_LEN]
            for line in stripped[start:] if line]
def _list_ascii_files(ascii_root):
    """Return paths of non-hidden files found in leaf directories under ascii_root."""
    paths = []
    for dirpath, dirnames, filenames in os.walk(ascii_root):
        if dirnames:
            # Only leaf directories (no subdirectories) are scanned.
            continue
        paths.extend(os.path.join(dirpath, name)
                     for name in filenames if not name.startswith('.'))
    return paths


def _read_writer_id(original_xml):
    """Return int(writerID) of the first child of <General>, or 0 if unavailable."""
    root = ElementTree.parse(original_xml).getroot()
    general = root.find('General')
    # Guard against a missing or childless <General> element.
    if general is None or len(general) == 0:
        return 0
    return int(general[0].attrib.get('writerID', '0'))


def collect_data():
    """Pair each line-stroke file with its transcription line and writer ID.

    Walks the leaf directories under ``data/raw/ascii/``. For every
    transcription file it:

    - lists the files in the parallel ``lineStrokes`` directory whose names
      start with ``<directory name><suffix letter>-`` (sorted);
    - reads the writer ID from the parallel
      ``original/strokes<suffix letter>.xml``;
    - pairs those files positionally with the lines returned by
      ``get_ascii_sequences``.

    Line-stroke file names listed in ``data/blacklist.npy`` are dropped.

    Returns:
        Three parallel lists ``(stroke_fnames, transcriptions, writer_ids)``:
        line-stroke file paths, encoded transcriptions, and integer writer IDs.

    Raises:
        ValueError: if a transcription file's line count does not match its
            number of line-stroke files. This used to be an ``assert``, which
            ``python -O`` strips, letting ``zip`` silently misalign the pairs.
    """
    ascii_root = 'data/raw/ascii/'
    fnames = _list_ascii_files(ascii_root)

    # low quality samples (selected by collecting samples to
    # which the trained model assigned very low likelihood)
    blacklist = set(np.load('data/blacklist.npy'))

    # Hard-coded exclusion kept from the original script. Normalized so the
    # comparison also works with the platform's path separator.
    excluded = os.path.normpath('data/raw/ascii/z01/z01-000/z01-000z.txt')

    stroke_fnames, transcriptions, writer_ids = [], [], []
    for i, fname in enumerate(fnames):
        print(i, fname)
        if os.path.normpath(fname) == excluded:
            continue

        head = os.path.dirname(fname)
        # Trailing letter of the file stem (e.g. 'u' in '...-000u'), or ''.
        last_letter = os.path.splitext(fname)[0][-1]
        last_letter = last_letter if last_letter.isalpha() else ''

        # Replace only the first 'ascii', i.e. the root path component.
        line_stroke_dir = head.replace('ascii', 'lineStrokes', 1)
        line_stroke_fname_prefix = os.path.basename(head) + last_letter + '-'
        if not os.path.isdir(line_stroke_dir):
            continue
        line_stroke_fnames = sorted(f for f in os.listdir(line_stroke_dir)
                                    if f.startswith(line_stroke_fname_prefix))
        if not line_stroke_fnames:
            continue

        original_dir = head.replace('ascii', 'original', 1)
        original_xml = os.path.join(original_dir, 'strokes' + last_letter + '.xml')
        writer_id = _read_writer_id(original_xml)

        ascii_sequences = get_ascii_sequences(fname)
        if len(ascii_sequences) != len(line_stroke_fnames):
            raise ValueError(
                '%s: %d transcription lines but %d line-stroke files'
                % (fname, len(ascii_sequences), len(line_stroke_fnames)))

        for ascii_seq, line_stroke_fname in zip(ascii_sequences, line_stroke_fnames):
            if line_stroke_fname in blacklist:
                continue
            stroke_fnames.append(os.path.join(line_stroke_dir, line_stroke_fname))
            transcriptions.append(ascii_seq)
            writer_ids.append(writer_id)

    return stroke_fnames, transcriptions, writer_ids
if __name__ == '__main__':
    # Collect all samples, pack them into fixed-size zero-padded arrays, and
    # save the rows that pass the outlier filter under data/processed/.
    print('traversing data directory...')
    stroke_fnames, transcriptions, writer_ids = collect_data()

    print('dumping to numpy arrays...')
    num_samples = len(stroke_fnames)
    x = np.zeros([num_samples, drawing.MAX_STROKE_LEN, 3], dtype=np.float32)
    x_len = np.zeros([num_samples], dtype=np.int16)
    c = np.zeros([num_samples, drawing.MAX_CHAR_LEN], dtype=np.int8)
    # NOTE(review): int8 holds at most 127; assumes drawing.MAX_CHAR_LEN <= 127 — confirm.
    c_len = np.zeros([num_samples], dtype=np.int8)
    w_id = np.zeros([num_samples], dtype=np.int16)
    # np.bool was a deprecated alias of the builtin bool and was removed in
    # NumPy 1.24; np.bool_ is the NumPy boolean dtype on every version.
    valid_mask = np.zeros([num_samples], dtype=np.bool_)

    for i, (stroke_fname, c_i, w_id_i) in enumerate(zip(stroke_fnames, transcriptions, writer_ids)):
        if i % 200 == 0:
            print(i, '\t', '/', num_samples)

        x_i = get_stroke_sequence(stroke_fname)
        # A sample is valid only if no row's (x, y) Euclidean norm exceeds 60.
        valid_mask[i] = not np.any(np.linalg.norm(x_i[:, :2], axis=1) > 60)

        x[i, :len(x_i), :] = x_i
        x_len[i] = len(x_i)
        c[i, :len(c_i)] = c_i
        c_len[i] = len(c_i)
        w_id[i] = w_id_i

    if not os.path.isdir('data/processed'):
        os.makedirs('data/processed')
    # Save each array filtered to the valid rows only.
    for name, arr in [('x', x), ('x_len', x_len), ('c', c),
                      ('c_len', c_len), ('w_id', w_id)]:
        np.save('data/processed/%s.npy' % name, arr[valid_mask])