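"""Preprocess raw handwriting data into padded numpy arrays.

Pairs each ascii transcription line under data/raw/ascii/ with its
line-stroke XML file under data/raw/lineStrokes/, then dumps fixed-size
stroke and character arrays to data/processed/. The directory layout
(ascii/, lineStrokes/, original/) matches the IAM On-Line Handwriting
Database, which this script appears to target.
"""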
from __future__ import print_function
import os
from xml.etree import ElementTree

import numpy as np

import drawing
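
# drawing is assumed to be a local helper module shipped alongside this
# script (providing align, denoise, coords_to_offsets, normalize,
# encode_ascii, MAX_STROKE_LEN and MAX_CHAR_LEN), not a PyPI package.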

def get_stroke_sequence(filename):
    """Parse a lineStrokes XML file into a normalized offset sequence."""
    tree = ElementTree.parse(filename).getroot()
    strokes = [i for i in tree if i.tag == 'StrokeSet'][0]

    coords = []
    for stroke in strokes:
        for i, point in enumerate(stroke):
            coords.append([
                int(point.attrib['x']),
                -1*int(point.attrib['y']),  # flip y so the text reads upright
                int(i == len(stroke) - 1)   # 1 marks the last point of a stroke
            ])
    coords = np.array(coords)

    coords = drawing.align(coords)
    coords = drawing.denoise(coords)
    offsets = drawing.coords_to_offsets(coords)
    offsets = offsets[:drawing.MAX_STROKE_LEN]
    offsets = drawing.normalize(offsets)
    return offsets
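
# A minimal usage sketch (the path below is hypothetical, and it assumes
# drawing.coords_to_offsets preserves the end-of-stroke flag in the third
# column):
#
#     offsets = get_stroke_sequence('data/raw/lineStrokes/a01/a01-000/a01-000u-01.xml')
#     offsets.shape  # (T, 3): (dx, dy, end_of_stroke) rows, T <= drawing.MAX_STROKE_LEN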

def get_ascii_sequences(filename):
    """Read an ascii transcription file and encode the lines after 'CSR:'."""
    with open(filename, 'r') as f:
        sequences = f.read()
    sequences = sequences.replace(r'%%%%%%%%%%%', '\n')
    sequences = [i.strip() for i in sequences.split('\n')]
    lines = sequences[sequences.index('CSR:') + 2:]  # skip the label and the line after it
    lines = [line.strip() for line in lines if line.strip()]
    lines = [drawing.encode_ascii(line)[:drawing.MAX_CHAR_LEN] for line in lines]
    return lines
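
# For reference, the assumed layout of a transcription file (only the 'CSR:'
# marker is taken from the code above; the surrounding lines are illustrative):
#
#     OCR:
#     A MOVE to stop Mr. Gaitskell ...
#
#     CSR:
#
#     A MOVE to stop Mr. Gaitskell
#     from nominating any more Labour
#
# The '+ 2' indexing skips the 'CSR:' label and the blank line that follows it.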

def collect_data():
    """Gather (stroke file, transcription, writer id) triples from data/raw/."""
    fnames = []
    for dirpath, dirnames, filenames in os.walk('data/raw/ascii/'):
        if dirnames:
            continue
        for filename in filenames:
            if filename.startswith('.'):
                continue
            fnames.append(os.path.join(dirpath, filename))

    # low quality samples (selected by collecting samples to which the
    # trained model assigned very low likelihood)
    blacklist = set(np.load('data/blacklist.npy'))
    stroke_fnames, transcriptions, writer_ids = [], [], []
    for i, fname in enumerate(fnames):
        print(i, fname)
        # skip a known-bad sample
        if fname == 'data/raw/ascii/z01/z01-000/z01-000z.txt':
            continue

        head, tail = os.path.split(fname)
        last_letter = os.path.splitext(fname)[0][-1]
        last_letter = last_letter if last_letter.isalpha() else ''

        # derive the sibling lineStrokes directory and filename prefix
        line_stroke_dir = head.replace('ascii', 'lineStrokes')
        line_stroke_fname_prefix = os.path.split(head)[-1] + last_letter + '-'

        if not os.path.isdir(line_stroke_dir):
            continue
        line_stroke_fnames = sorted([f for f in os.listdir(line_stroke_dir)
                                     if f.startswith(line_stroke_fname_prefix)])
        if not line_stroke_fnames:
            continue

        original_dir = head.replace('ascii', 'original')
        original_xml = os.path.join(original_dir, 'strokes' + last_letter + '.xml')
        tree = ElementTree.parse(original_xml)
        root = tree.getroot()

        general = root.find('General')
        if general is not None:
            writer_id = int(general[0].attrib.get('writerID', '0'))
        else:
            writer_id = 0

        ascii_sequences = get_ascii_sequences(fname)
        assert len(ascii_sequences) == len(line_stroke_fnames)

        for ascii_seq, line_stroke_fname in zip(ascii_sequences, line_stroke_fnames):
            if line_stroke_fname in blacklist:
                continue

            stroke_fnames.append(os.path.join(line_stroke_dir, line_stroke_fname))
            transcriptions.append(ascii_seq)
            writer_ids.append(writer_id)

    return stroke_fnames, transcriptions, writer_ids
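
# Illustrative example of the pairing collect_data() produces (hypothetical
# paths, following the prefix logic above):
#
#     data/raw/ascii/a01/a01-000/a01-000u.txt
#         -> data/raw/lineStrokes/a01/a01-000/a01-000u-01.xml, a01-000u-02.xml, ...
#
# one stroke XML per transcription line, matched in sorted order.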

if __name__ == '__main__':
    print('traversing data directory...')
    stroke_fnames, transcriptions, writer_ids = collect_data()

    print('dumping to numpy arrays...')
    x = np.zeros([len(stroke_fnames), drawing.MAX_STROKE_LEN, 3], dtype=np.float32)
    x_len = np.zeros([len(stroke_fnames)], dtype=np.int16)
    c = np.zeros([len(stroke_fnames), drawing.MAX_CHAR_LEN], dtype=np.int8)
    c_len = np.zeros([len(stroke_fnames)], dtype=np.int8)
    w_id = np.zeros([len(stroke_fnames)], dtype=np.int16)
    valid_mask = np.zeros([len(stroke_fnames)], dtype=bool)  # np.bool was removed in NumPy 1.24

    for i, (stroke_fname, c_i, w_id_i) in enumerate(zip(stroke_fnames, transcriptions, writer_ids)):
        if i % 200 == 0:
            print(i, '/', len(stroke_fnames))

        x_i = get_stroke_sequence(stroke_fname)
        # drop samples containing implausibly large pen jumps
        valid_mask[i] = not np.any(np.linalg.norm(x_i[:, :2], axis=1) > 60)

        x[i, :len(x_i), :] = x_i
        x_len[i] = len(x_i)

        c[i, :len(c_i)] = c_i
        c_len[i] = len(c_i)

        w_id[i] = w_id_i
    if not os.path.isdir('data/processed'):
        os.makedirs('data/processed')

    np.save('data/processed/x.npy', x[valid_mask])
    np.save('data/processed/x_len.npy', x_len[valid_mask])
    np.save('data/processed/c.npy', c[valid_mask])
    np.save('data/processed/c_len.npy', c_len[valid_mask])
    np.save('data/processed/w_id.npy', w_id[valid_mask])
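
# A quick sanity check after running, as a sketch (shapes follow from the
# array allocation above):
#
#     x = np.load('data/processed/x.npy')    # (N, drawing.MAX_STROKE_LEN, 3), float32
#     c = np.load('data/processed/c.npy')    # (N, drawing.MAX_CHAR_LEN), int8
#     assert len(x) == len(c)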