# hand.py -- uploaded by 3morrrrr; last commit "Update hand.py" (f94c6e7, verified).
import drawing
from rnn import rnn
import numpy as np
import svgwrite
import logging
import os
class Hand(object):
    """Generate handwriting for lines of text and render it as an SVG file.

    ``__init__`` builds the project's ``rnn`` model and calls its
    ``restore()``; ``write`` validates the text, samples pen strokes from the
    model and draws them with ``svgwrite``.
    """

    def __init__(self):
        """Build the network with fixed hyper-parameters and restore it."""
        # Presumably reduces TensorFlow's native log output; this only helps
        # if TensorFlow reads the variable after this point -- TODO confirm.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
        self.nn = rnn(
            log_dir='logs',
            checkpoint_dir='checkpoints',
            prediction_dir='predictions',
            learning_rates=[.0001, .00005, .00002],
            batch_sizes=[32, 64, 64],
            patiences=[1500, 1000, 500],
            beta1_decays=[.9, .9, .9],
            validation_batch_size=32,
            optimizer='rms',
            num_training_steps=100000,
            warm_start_init_step=17900,
            regularization_constant=0.0,
            keep_prob=1.0,
            enable_parameter_averaging=False,
            min_steps_to_checkpoint=2000,
            log_interval=20,
            logging_level=logging.CRITICAL,
            grad_clip=10,
            lstm_size=400,
            output_mixture_components=20,
            attention_mixture_components=10
        )
        self.nn.restore()

    def write(self, filename, lines, biases=None, styles=None, stroke_colors=None, stroke_widths=None):
        """Validate ``lines``, sample strokes for them and save an SVG.

        Args:
            filename: Path of the SVG file to create.
            lines: Sequence of strings, one per handwritten line. Each line
                may hold at most 75 characters, all of which must appear in
                ``drawing.alphabet``.
            biases: Optional per-line sampling bias; 0.5 per line when omitted.
            styles: Optional per-line style ids; each id selects the files
                ``styles/style-<id>-strokes.npy`` and
                ``styles/style-<id>-chars.npy`` used to prime the model.
            stroke_colors: Optional per-line stroke colours ('black' when
                omitted).
            stroke_widths: Optional per-line stroke widths (4 when omitted).

        Raises:
            ValueError: If ``lines`` is empty, a line is longer than 75
                characters, or a line contains a character outside the
                alphabet.
        """
        # Fail early with a clear message; otherwise the same ValueError type
        # would surface later from max() over an empty sequence in _sample.
        if len(lines) == 0:
            raise ValueError("lines must contain at least one line")

        valid_char_set = set(drawing.alphabet)
        for line_num, line in enumerate(lines):
            if len(line) > 75:
                raise ValueError(
                    (
                        "Each line must be at most 75 characters. "
                        "Line {} contains {}"
                    ).format(line_num, len(line))
                )

            for char in line:
                if char not in valid_char_set:
                    raise ValueError(
                        (
                            "Invalid character {} detected in line {}. "
                            "Valid character set is {}"
                        ).format(char, line_num, valid_char_set)
                    )

        strokes = self._sample(lines, biases=biases, styles=styles)
        self._draw(strokes, lines, filename, stroke_colors=stroke_colors, stroke_widths=stroke_widths)

    def _sample(self, lines, biases=None, styles=None):
        """Run the model once and return one stroke array per line.

        Args:
            lines: Sequence of strings to synthesise.
            biases: Optional per-line bias values; 0.5 per line when omitted.
            styles: Optional per-line style ids. When given, the model is
                primed with the stored strokes/characters of each style.

        Returns:
            A list with one array per sample, with all-zero rows removed.
        """
        num_samples = len(lines)
        # Sampling budget: 40 time steps per character of the longest line.
        max_tsteps = 40 * max(len(line) for line in lines)
        biases = biases if biases is not None else [0.5] * num_samples

        # Fixed-size, zero-padded feed buffers. A style with more than 1200
        # stroke rows or an encoded text longer than 120 entries does not fit
        # and makes the slice assignments below fail.
        x_prime = np.zeros([num_samples, 1200, 3])
        x_prime_len = np.zeros([num_samples])
        chars = np.zeros([num_samples, 120])
        chars_len = np.zeros([num_samples])

        if styles is not None:
            # NOTE: zip() stops at the shorter of lines/styles, so surplus
            # lines keep all-zero (empty) inputs.
            for i, (text, style) in enumerate(zip(lines, styles)):
                x_p = np.load('styles/style-{}-strokes.npy'.format(style))
                # ndarray.tostring() was removed in NumPy 2.0; tobytes()
                # returns exactly the same bytes.
                c_p = np.load('styles/style-{}-chars.npy'.format(style)).tobytes().decode('utf-8')

                # The style's own text is prepended so the priming strokes
                # line up with the characters they were written for.
                c_p = str(c_p) + " " + text
                c_p = drawing.encode_ascii(c_p)
                c_p = np.array(c_p)

                x_prime[i, :len(x_p), :] = x_p
                x_prime_len[i] = len(x_p)
                chars[i, :len(c_p)] = c_p
                chars_len[i] = len(c_p)
        else:
            for i in range(num_samples):
                encoded = drawing.encode_ascii(lines[i])
                chars[i, :len(encoded)] = encoded
                chars_len[i] = len(encoded)

        [samples] = self.nn.session.run(
            [self.nn.sampled_sequence],
            feed_dict={
                self.nn.prime: styles is not None,
                self.nn.x_prime: x_prime,
                self.nn.x_prime_len: x_prime_len,
                self.nn.num_samples: num_samples,
                self.nn.sample_tsteps: max_tsteps,
                self.nn.c: chars,
                self.nn.c_len: chars_len,
                self.nn.bias: biases
            }
        )
        # Drop rows that are entirely zero (presumably padding after the end
        # of each sampled sequence -- TODO confirm against the model).
        samples = [sample[~np.all(sample == 0.0, axis=1)] for sample in samples]
        return samples

    def _draw(self, strokes, lines, filename, stroke_colors=None, stroke_widths=None):
        """Render sampled strokes to an SVG file, one text line per row.

        Args:
            strokes: Per-line arrays of pen offsets as returned by ``_sample``.
                They are not modified.
            lines: The text lines; an empty string leaves a blank row.
            filename: Path of the SVG file to write.
            stroke_colors: Optional per-line colours ('black' when omitted).
            stroke_widths: Optional per-line widths (4 when omitted).
        """
        stroke_colors = stroke_colors or ['black'] * len(lines)
        stroke_widths = stroke_widths or [4] * len(lines)

        line_height = 80   # vertical space reserved per text line
        view_width = 600   # width of the SVG viewBox
        view_height = line_height * (len(strokes) + 1)

        dwg = svgwrite.Drawing(filename=filename)
        dwg.viewbox(width=view_width, height=view_height)
        dwg.add(dwg.rect(insert=(0, 0), size=(view_width, view_height), fill='white'))

        # Running offset of the current row; y moves further negative per row
        # and is subtracted from the coordinates below.
        initial_coord = np.array([0, -(3 * line_height / 4)])
        for offsets, line, color, width in zip(strokes, lines, stroke_colors, stroke_widths):
            if not line:
                # Blank line: only advance to the next row.
                initial_coord[1] -= line_height
                continue

            # Scale a copy so the caller's arrays are left untouched (scaling
            # in place would compound if the same strokes were drawn twice).
            offsets = offsets.copy()
            offsets[:, :2] *= 3.0

            coords = drawing.offsets_to_coords(offsets)
            coords = drawing.denoise(coords)
            coords[:, :2] = drawing.align(coords[:, :2])

            coords[:, 1] *= -1  # flip y
            # NOTE: min() is taken over x and y jointly (a single scalar).
            coords[:, :2] -= coords[:, :2].min() + initial_coord
            # Centre the line horizontally within the viewBox.
            coords[:, 0] += (view_width - coords[:, 0].max()) / 2

            path = svgwrite.path.Path(self._path_data(coords))
            path = path.stroke(color=color, width=width, linecap='round').fill("none")
            dwg.add(path)

            initial_coord[1] -= line_height

        dwg.save()

    @staticmethod
    def _path_data(coords):
        """Return the SVG path data string for rows of ``(x, y, eos)``.

        The string starts with ``M0,0``. Each row is emitted as a move
        (``M``) when it is the first row or the previous row's ``eos`` equals
        1.0, and as a line (``L``) otherwise.
        """
        segments = ["M{},{} ".format(0, 0)]
        prev_eos = 1.0
        for x, y, eos in coords:
            segments.append('{}{},{} '.format('M' if prev_eos == 1.0 else 'L', x, y))
            prev_eos = eos
        return ''.join(segments)