from config.GlobalVariables import *  # provides CHARACTERS, among others
import math
import torch
from scipy import stats
import numpy as np
from PIL import Image, ImageDraw
import os
import pickle


def preprocess_dataset(data_dir, resample=20, pred_start=1):
    def reformat_raw_data(raw_data, pred_start):
        # Convert absolute (x, y, pen) points into per-step offsets and
        # return the (input, target) pair for next-step prediction.
        if pred_start == 1:
            tmp = np.concatenate([[[0, 500, 0]], raw_data], 0)
            tmp = tmp[1:] - tmp[:-1]
            tmp[1:, 2] = raw_data[:-1, 2]
            tmp = np.concatenate([[[0, 0, 0]], tmp], 0)
        else:
            tmp = np.concatenate([raw_data[0:1], raw_data])
            tmp = tmp[1:] - tmp[:-1]
            tmp[0, 2] = 0
            tmp[1:, 2] = raw_data[:-1, 2]
        return tmp[:-1], tmp[1:]

    # Samples excluded from preprocessing. Note that these hard-coded paths
    # only match when data_dir is './BRUSH'.
    prohibits = [f'./BRUSH/5/118_resample{resample}',
                 f'./BRUSH/7/14_resample{resample}',
                 f'./BRUSH/7/101_resample{resample}',
                 f'./BRUSH/7/58_resample{resample}',
                 f'./BRUSH/14/20_resample{resample}',
                 f'./BRUSH/22/45_resample{resample}',
                 f'./BRUSH/30/45_resample{resample}',
                 f'./BRUSH/40/85_resample{resample}',
                 f'./BRUSH/50/45_resample{resample}',
                 f'./BRUSH/59/29_resample{resample}',
                 f'./BRUSH/96/120_resample{resample}',
                 f'./BRUSH/99/134_resample{resample}',
                 f'./BRUSH/140/35_resample{resample}',
                 f'./BRUSH/144/55_resample{resample}',
                 f'./BRUSH/144/91_resample{resample}',
                 f'./BRUSH/144/28_resample{resample}',
                 f'./BRUSH/144/69_resample{resample}']
    preprocess_dir = 'preprocess' if pred_start == 1 else 'preprocess2'
    for writer_id in range(170):
        print(f'Preprocessing the BRUSH dataset - Writer ID: {writer_id + 1} / 170')
        # endswith() is robust to resample values of any digit width,
        # unlike the fixed-length slice i[-3:].
        for sentence_id in [i for i in os.listdir(f'{data_dir}/{writer_id}')
                            if i.endswith(f'e{resample}')]:
            with open(f'{data_dir}/{writer_id}/{sentence_id}', 'rb') as f:
                [sentence_text, raw_points, character_labels] = pickle.load(f)
            if f'{data_dir}/{writer_id}/{sentence_id}' not in prohibits:
                # Shift x so the sentence starts at x = 0. This is in place,
                # so the per-character and per-word slices below already see
                # the shifted coordinates.
                sentence_raw_points = raw_points
                sentence_raw_points[:, 0] -= sentence_raw_points[0, 0]
                sentence_stroke_in, sentence_stroke_out = reformat_raw_data(
                    sentence_raw_points, pred_start=pred_start)
                split_char_ids = [i for i, c in enumerate(sentence_text) if c == ' ']
                sentence_char = [CHARACTERS.find(c) for c in sentence_text]

                # Build a per-point termination flag: 1 at the last point of
                # each character, 0 elsewhere.
                sentence_term = []
                cid = 0
                for i in range(len(character_labels) - 1):
                    if character_labels[i + 1, cid] != 1:
                        if np.argmax(character_labels[i + 1]) >= cid:
                            cid += 1
                            sentence_term.append(1)
                        else:
                            sentence_term.append(0)
                    else:
                        sentence_term.append(0)
                sentence_term.append(1)
                sentence_term = np.asarray(sentence_term)
                assert len(sentence_term) == len(character_labels)

                word_level_raw_stroke = []
                word_level_stroke_in = []
                word_level_stroke_out = []
                word_level_char = []
                word_level_term = []
                segment_level_raw_stroke = []
                segment_level_stroke_in = []
                segment_level_stroke_out = []
                segment_level_char = []
                segment_level_term = []
                character_level_raw_stroke = []
                character_level_stroke_in = []
                character_level_stroke_out = []
                character_level_char = []
                character_level_term = []

                word_start_id = 0
                for i, c in enumerate(sentence_text):
                    if c != ' ':
                        # Character-level slice, renormalised to start at x = 0.
                        character_raw_points = raw_points[character_labels[:, i] > 0]
                        character_raw_points[:, 0] -= character_raw_points[0, 0]
                        character_level_raw_stroke.append(character_raw_points)
                        character_stroke_in, character_stroke_out = reformat_raw_data(
                            character_raw_points, pred_start=pred_start)
                        character_level_stroke_in.append(character_stroke_in)
                        character_level_stroke_out.append(character_stroke_out)
                        term = np.zeros([len(character_raw_points)])
                        term[-1] = 1
                        character_level_term.append(term)
                        character_level_char.append([CHARACTERS.find(c)])
                    if i in split_char_ids:
                        # A space ends the current word: collect its points and
                        # flush the per-character buffers into the segment level.
                        word = sentence_text[word_start_id:i]
                        word_labels = np.zeros(len(character_labels))
                        for j in range(word_start_id, i):
                            word_labels += character_labels[:, j]
                        word_raw_points = raw_points[word_labels > 0]
                        word_term = sentence_term[word_labels > 0]
                        word_term[0] = 0
                        assert np.sum(word_term) == len(word)
                        word_raw_points[:, 0] -= word_raw_points[0, 0]
                        word_level_raw_stroke.append(word_raw_points)
                        word_stroke_in, word_stroke_out = reformat_raw_data(
                            word_raw_points, pred_start=pred_start)
                        word_level_stroke_in.append(word_stroke_in)
                        word_level_stroke_out.append(word_stroke_out)
                        word_level_term.append(word_term)
                        word_level_char.append([CHARACTERS.find(c) for c in word])
                        word_start_id = i + 1
                        assert len(character_level_raw_stroke) == len(word)
                        segment_level_raw_stroke.append(character_level_raw_stroke)
                        segment_level_stroke_in.append(character_level_stroke_in)
                        segment_level_stroke_out.append(character_level_stroke_out)
                        segment_level_char.append(character_level_char)
                        segment_level_term.append(character_level_term)
                        character_level_raw_stroke = []
                        character_level_stroke_in = []
                        character_level_stroke_out = []
                        character_level_char = []
                        character_level_term = []

                # The last word has no trailing space; handle it the same way.
                word = sentence_text[word_start_id:]
                word_labels = np.zeros(len(character_labels))
                for j in range(word_start_id, len(sentence_text)):
                    word_labels += character_labels[:, j]
                word_raw_points = raw_points[word_labels > 0]
                word_raw_points[:, 0] -= word_raw_points[0, 0]
                word_term = sentence_term[word_labels > 0]
                word_term[0] = 0
                assert np.sum(word_term) == len(word)
                word_level_raw_stroke.append(word_raw_points)
                word_stroke_in, word_stroke_out = reformat_raw_data(
                    word_raw_points, pred_start=pred_start)
                word_level_stroke_in.append(word_stroke_in)
                word_level_stroke_out.append(word_stroke_out)
                word_level_term.append(word_term)
                word_level_char.append([CHARACTERS.find(c) for c in word])
                assert len(character_level_raw_stroke) == len(word)
                segment_level_raw_stroke.append(character_level_raw_stroke)
                segment_level_stroke_in.append(character_level_stroke_in)
                segment_level_stroke_out.append(character_level_stroke_out)
                segment_level_char.append(character_level_char)
                segment_level_term.append(character_level_term)

                # makedirs also creates the intermediate preprocess directory,
                # which a plain os.mkdir would not.
                os.makedirs(f'{data_dir}/{preprocess_dir}/{writer_id}', exist_ok=True)
                with open(f'{data_dir}/{preprocess_dir}/{writer_id}/{sentence_id}', 'wb') as f:
                    pickle.dump([sentence_stroke_in, sentence_stroke_out,
                                 sentence_term, sentence_char,
                                 word_level_stroke_in, word_level_stroke_out,
                                 word_level_term, word_level_char,
                                 segment_level_stroke_in, segment_level_stroke_out,
                                 segment_level_term, segment_level_char], f)
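# A minimal sketch of reloading one preprocessed sentence. The twelve-item
# layout mirrors the pickle.dump call above; the concrete path is hypothetical
# and assumes preprocess_dataset('./BRUSH') has already been run:
#
#   with open('./BRUSH/preprocess/0/1_resample20', 'rb') as f:
#       (sentence_stroke_in, sentence_stroke_out, sentence_term, sentence_char,
#        word_level_stroke_in, word_level_stroke_out, word_level_term,
#        word_level_char, segment_level_stroke_in, segment_level_stroke_out,
#        segment_level_term, segment_level_char) = pickle.load(f)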
def gaussian_2d(x1, x2, mu1, mu2, s1, s2, rho):
    # Density of a bivariate Gaussian with means (mu1, mu2), standard
    # deviations (s1, s2) and correlation rho, evaluated at (x1, x2).
    norm1 = x1 - mu1
    norm2 = x2 - mu2
    s1s2 = s1 * s2
    z = (norm1 / s1) ** 2 + (norm2 / s2) ** 2 - 2 * rho * norm1 * norm2 / s1s2
    numerator = torch.exp(-z / (2 * (1 - rho ** 2)))
    denominator = 2 * math.pi * s1s2 * torch.sqrt(1 - rho ** 2)
    return numerator / denominator


def get_minimax(stroke_results):
    # Collect the local minima and maxima (in y) of each stroke.
    minimas = []
    maximas = []
    for stroke in stroke_results:
        for i, [x, y] in enumerate(stroke):
            if i == 0:
                prev_x, prev_y = x, y
            if i == len(stroke) - 1:
                if prev_y <= y:
                    maximas.append([x, y])
                if prev_y >= y:
                    minimas.append([x, y])
                break
            else:
                next_x, next_y = stroke[i + 1]
                if prev_y <= y and y >= next_y:
                    maximas.append([x, y])
                if prev_y >= y and y <= next_y:
                    minimas.append([x, y])
                # Advance the previous point; without this, every comparison
                # would be against the first point of the stroke.
                prev_x, prev_y = x, y
    minimas = np.asarray(minimas)
    maximas = np.asarray(maximas)
    return minimas, maximas
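# A quick sanity check for gaussian_2d (illustrative, not part of the
# pipeline): a standard, uncorrelated 2-D Gaussian has density 1 / (2 * pi)
# at its mean.
#
#   zero, one = torch.tensor(0.0), torch.tensor(1.0)
#   p = gaussian_2d(zero, zero, zero, zero, one, one, zero)
#   assert torch.isclose(p, torch.tensor(1.0 / (2.0 * math.pi)))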
def get_slope(minimas, maximas):
    # Fit a line through the minima and one through the maxima, then refit on
    # the points with below-average absolute residuals to reduce the
    # influence of outliers.
    minima_slope, minima_intercept, _, _, _ = stats.linregress(minimas[:, 0], minimas[:, 1])
    maxima_slope, maxima_intercept, _, _, _ = stats.linregress(maximas[:, 0], maximas[:, 1])
    min_se = []
    max_se = []
    for [x, y] in minimas:
        min_ny = minima_slope * x + minima_intercept
        min_se.append(abs(min_ny - y))
    for [x, y] in maximas:
        max_ny = maxima_slope * x + maxima_intercept
        max_se.append(abs(max_ny - y))
    min_se, max_se = np.asarray(min_se), np.asarray(max_se)
    new_minimas = minimas[min_se < np.mean(min_se)]
    new_maximas = maximas[max_se < np.mean(max_se)]
    if len(new_minimas) > 5:
        minima_slope, minima_intercept, _, _, _ = stats.linregress(new_minimas[:, 0], new_minimas[:, 1])
    if len(new_maximas) > 5:
        maxima_slope, maxima_intercept, _, _, _ = stats.linregress(new_maximas[:, 0], new_maximas[:, 1])
    return minima_slope, minima_intercept, maxima_slope, maxima_intercept


def draw_commands(commands):
    # Render a sequence of (dx, dy, pen) offset commands onto an 8-bit
    # grayscale canvas (float64 would give an 'F'-mode image that is awkward
    # to draw on and save). A segment is drawn only when the pen flag is 0.
    im = Image.fromarray(np.zeros([160, 750], dtype=np.uint8))
    dr = ImageDraw.Draw(im)
    px, py = 50, 100
    for dx, dy, t in commands:
        x = px + dx * 5
        y = py + dy * 5
        if t == 0:
            dr.line((px, py, x, y), 255, 1)
        px, py = x, y
    return im


def draw_points(raw_points, character_labels):
    # Render absolute (x, y, pen) points, coloring each segment by the
    # character it belongs to; a segment is drawn only when the previous
    # point's pen flag was 0.
    [w, h, _] = np.max(raw_points, 0)
    im = Image.new("RGB", (int(w) + 100, int(h) + 100))
    dr = ImageDraw.Draw(im)
    colors = np.random.randint(0, 255, (len(character_labels[0]), 3))
    for i, [x, y, t] in enumerate(raw_points):
        if i > 0 and pt == 0:
            # Cast to plain ints so PIL accepts the fill color.
            color = tuple(int(v) for v in colors[np.argmax(character_labels[i])])
            dr.line((px, py, x, y), color, 3)
        px, py, pt = x, y, t
    return im
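# A minimal end-to-end sketch of the utilities above (hypothetical variable
# names and output path): detect the extrema of a set of strokes, estimate
# the lower and upper slopes (roughly, the baseline and topline of the
# handwriting), and render a preview of offset commands.
#
#   minimas, maximas = get_minimax(strokes)       # strokes: list of (N_i, 2) arrays
#   lo_slope, lo_b, hi_slope, hi_b = get_slope(minimas, maximas)
#   draw_commands(commands).save('preview.png')   # commands: (N, 3) [dx, dy, pen]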