Spaces:
Running
Running
Create script.js
Browse files
script.js
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
var log = console.log;
|
2 |
+
var ctx = null;
|
3 |
+
var canvas = null;
|
4 |
+
var RNN_SIZE = 512;
|
5 |
+
var cur_run = 0;
|
6 |
+
|
7 |
+
var randn = function() {
|
8 |
+
// Standard Normal random variable using Box-Muller transform.
|
9 |
+
var u = Math.random() * 0.999 + 1e-5;
|
10 |
+
var v = Math.random() * 0.999 + 1e-5;
|
11 |
+
return Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v);
|
12 |
+
}
|
13 |
+
|
14 |
+
var rand_truncated_normal = function(low, high) {
|
15 |
+
while (true) {
|
16 |
+
r = randn();
|
17 |
+
if (r >= low && r <= high)
|
18 |
+
break;
|
19 |
+
// rejection sampling.
|
20 |
+
}
|
21 |
+
return r;
|
22 |
+
}
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
var char2idx = {'\x00': 0, ' ': 1, '!': 2, '"': 3, '#': 4, "'": 5, '(': 6, ')': 7, ',': 8, '-': 9, '.': 10, '0': 11, '1': 12, '2': 13, '3': 14, '4': 15, '5': 16, '6': 17, '7': 18, '8': 19, '9': 20, ':': 21, ';': 22, '?': 23, 'A': 24, 'B': 25, 'C': 26, 'D': 27, 'E': 28, 'F': 29, 'G': 30, 'H': 31, 'I': 32, 'J': 33, 'K': 34, 'L': 35, 'M': 36, 'N': 37, 'O': 38, 'P': 39, 'R': 40, 'S': 41, 'T': 42, 'U': 43, 'V': 44, 'W': 45, 'Y': 46, 'a': 47, 'b': 48, 'c': 49, 'd': 50, 'e': 51, 'f': 52, 'g': 53, 'h': 54, 'i': 55, 'j': 56, 'k': 57, 'l': 58, 'm': 59, 'n': 60, 'o': 61, 'p': 62, 'q': 63, 'r': 64, 's': 65, 't': 66, 'u': 67, 'v': 68, 'w': 69, 'x': 70, 'y': 71, 'z': 72};
|
27 |
+
|
28 |
+
var gru_core = function(input, weights, state, hidden_size) {
|
29 |
+
var [w_h,w_i,b] = weights;
|
30 |
+
var [w_h_z,w_h_a] = tf.split(w_h, [2 * hidden_size, hidden_size], 1);
|
31 |
+
var [b_z,b_a] = tf.split(b, [2 * hidden_size, hidden_size], 0);
|
32 |
+
gates_x = tf.matMul(input, w_i);
|
33 |
+
[zr_x,a_x] = tf.split(gates_x, [2 * hidden_size, hidden_size], 1);
|
34 |
+
zr_h = tf.matMul(state, w_h_z);
|
35 |
+
zr = tf.add(tf.add(zr_x, zr_h), b_z);
|
36 |
+
// fix this
|
37 |
+
[z,r] = tf.split(tf.sigmoid(zr), 2, 1);
|
38 |
+
a_h = tf.matMul(tf.mul(r, state), w_h_a);
|
39 |
+
a = tf.tanh(tf.add(tf.add(a_x, a_h), b_a));
|
40 |
+
next_state = tf.add(tf.mul(tf.sub(1., z), state), tf.mul(z, a));
|
41 |
+
return [next_state, next_state];
|
42 |
+
};
|
43 |
+
|
44 |
+
|
45 |
+
var generate = function() {
|
46 |
+
cur_run = cur_run + 1;
|
47 |
+
setTimeout(function() {
|
48 |
+
var counter = 2000;
|
49 |
+
tf.disposeVariables();
|
50 |
+
|
51 |
+
tf.engine().startScope();
|
52 |
+
ctx.clearRect(0, 0, canvas.width, canvas.height);
|
53 |
+
ctx.beginPath();
|
54 |
+
dojob(cur_run);
|
55 |
+
}, 200);
|
56 |
+
|
57 |
+
return false;
|
58 |
+
}
|
59 |
+
|
60 |
+
var dojob = function(run_id) {
|
61 |
+
var text = document.getElementById("user-input").value;
|
62 |
+
if (text.length == 0) {
|
63 |
+
text = "The quick brown fox jumps over the lazy dog";
|
64 |
+
}
|
65 |
+
|
66 |
+
var cur_x = 50.;
|
67 |
+
var cur_y = 300.;
|
68 |
+
|
69 |
+
|
70 |
+
log(text);
|
71 |
+
original_text = text;
|
72 |
+
text = '' + text + ' ' + text;
|
73 |
+
|
74 |
+
text = Array.from(text).map(function(e) {
|
75 |
+
return char2idx[e]
|
76 |
+
})
|
77 |
+
var text_embed = WEIGHTS['rnn/~/embed_1__embeddings'];
|
78 |
+
indices = tf.tensor1d(text, 'int32');
|
79 |
+
text = text_embed.gather(indices);
|
80 |
+
|
81 |
+
filter = WEIGHTS['rnn/~/conv1_d__w'];
|
82 |
+
embed = tf.conv1d(text, filter, 1, 'same');
|
83 |
+
bias = tf.expandDims(WEIGHTS['rnn/~/conv1_d__b'], 0);
|
84 |
+
embed = tf.add(embed, bias);
|
85 |
+
|
86 |
+
var writer_embed = WEIGHTS['rnn/~/embed__embeddings'];
|
87 |
+
var e = document.getElementById("writers");
|
88 |
+
var wid = parseInt(e.value);
|
89 |
+
// log(wid);
|
90 |
+
|
91 |
+
wid = tf.tensor1d([wid], 'int32');
|
92 |
+
wid = writer_embed.gather(wid);
|
93 |
+
embed = tf.add(wid, embed);
|
94 |
+
|
95 |
+
// initial state
|
96 |
+
var gru0_hx = tf.zeros([1, RNN_SIZE]);
|
97 |
+
var gru1_hx = tf.zeros([1, RNN_SIZE]);
|
98 |
+
// var gru2_hx = tf.zeros([1, RNN_SIZE]);
|
99 |
+
|
100 |
+
var att_location = tf.zeros([1, 1]);
|
101 |
+
var att_context = tf.zeros([1, 73]);
|
102 |
+
|
103 |
+
var input = tf.tensor([[0., 0., 1.]]);
|
104 |
+
|
105 |
+
gru0_w_h = WEIGHTS['rnn/~/lstm_attention_core/~/gru__w_h'];
|
106 |
+
gru0_w_i = WEIGHTS['rnn/~/lstm_attention_core/~/gru__w_i'];
|
107 |
+
gru0_bias = WEIGHTS['rnn/~/lstm_attention_core/~/gru__b'];
|
108 |
+
|
109 |
+
gru1_w_h = WEIGHTS['rnn/~/lstm_attention_core/~/gru_1__w_h'];
|
110 |
+
gru1_w_i = WEIGHTS['rnn/~/lstm_attention_core/~/gru_1__w_i'];
|
111 |
+
gru1_bias = WEIGHTS['rnn/~/lstm_attention_core/~/gru_1__b'];
|
112 |
+
att_w = WEIGHTS['rnn/~/lstm_attention_core/~/linear__w'];
|
113 |
+
att_b = WEIGHTS['rnn/~/lstm_attention_core/~/linear__b'];
|
114 |
+
gmm_w = WEIGHTS['rnn/~/linear__w'];
|
115 |
+
gmm_b = WEIGHTS['rnn/~/linear__b'];
|
116 |
+
|
117 |
+
ruler = tf.tensor([...Array(text.shape[0]).keys()]);
|
118 |
+
var bias = parseInt(document.getElementById("bias").value) / 100 * 3;
|
119 |
+
|
120 |
+
cur_x = 50.;
|
121 |
+
cur_y = 400.;
|
122 |
+
var path = [];
|
123 |
+
var dx = 0.;
|
124 |
+
var dy = 0;
|
125 |
+
var eos = 1.;
|
126 |
+
var counter = 0;
|
127 |
+
|
128 |
+
|
129 |
+
function loop(my_run_id) {
|
130 |
+
if (my_run_id < cur_run) {
|
131 |
+
tf.disposeVariables();
|
132 |
+
tf.engine().endScope();
|
133 |
+
return;
|
134 |
+
}
|
135 |
+
|
136 |
+
counter++;
|
137 |
+
if (counter < 2000) {
|
138 |
+
[att_location,att_context,gru0_hx,gru1_hx,input] = tf.tidy(function() {
|
139 |
+
// Attention
|
140 |
+
const inp_0 = tf.concat([att_context, input], 1);
|
141 |
+
gru0_hx_ = gru0_hx;
|
142 |
+
[out_0,gru0_hx] = gru_core(inp_0, [gru0_w_h, gru0_w_i, gru0_bias], gru0_hx, RNN_SIZE);
|
143 |
+
tf.dispose(gru0_hx_);
|
144 |
+
const att_inp = tf.concat([att_context, input, out_0], 1);
|
145 |
+
const att_params = tf.add(tf.matMul(att_inp, att_w), att_b);
|
146 |
+
[alpha,beta,kappa] = tf.split(tf.softplus(att_params), 3, 1);
|
147 |
+
att_location_ = att_location;
|
148 |
+
att_location = tf.add(att_location, tf.div(kappa, 25.));
|
149 |
+
tf.dispose(att_location_)
|
150 |
+
|
151 |
+
const phi = tf.mul(alpha, tf.exp(tf.div(tf.neg(tf.square(tf.sub(att_location, ruler))), beta)));
|
152 |
+
att_context_ = att_context;
|
153 |
+
att_context = tf.sum(tf.mul(tf.expandDims(phi, 2), tf.expandDims(embed, 0)), 1)
|
154 |
+
tf.dispose(att_context_);
|
155 |
+
|
156 |
+
const inp_1 = tf.concat([input, out_0, att_context], 1);
|
157 |
+
tf.dispose(input);
|
158 |
+
gru1_hx_ = gru1_hx;
|
159 |
+
[out_1,gru1_hx] = gru_core(inp_1, [gru1_w_h, gru1_w_i, gru1_bias], gru1_hx, RNN_SIZE);
|
160 |
+
tf.dispose(gru1_hx_);
|
161 |
+
|
162 |
+
// GMM
|
163 |
+
const gmm_params = tf.add(tf.matMul(out_1, gmm_w), gmm_b);
|
164 |
+
[x,y,logstdx,logstdy,angle,log_weight,eos_logit] = tf.split(gmm_params, [5, 5, 5, 5, 5, 5, 1], 1);
|
165 |
+
// log_weight = tf.softmax(log_weight, 1);
|
166 |
+
// log_weight = tf.log(log_weight);
|
167 |
+
// log_weight = tf.mul(log_weight, 1. + bias);
|
168 |
+
// const idx = tf.multinomial(log_weight, 1).dataSync()[0];
|
169 |
+
// log_weight = tf.softmax(log_weight, 1);
|
170 |
+
// log_weight = tf.log(log_weight);
|
171 |
+
// log_weight = tf.mul(log_weight, 1. + bias);
|
172 |
+
const idx = tf.argMax(log_weight, 1).dataSync()[0];
|
173 |
+
|
174 |
+
x = x.dataSync()[idx];
|
175 |
+
y = y.dataSync()[idx];
|
176 |
+
const stdx = tf.exp(tf.sub(logstdx, bias)).dataSync()[idx];
|
177 |
+
const stdy = tf.exp(tf.sub(logstdy, bias)).dataSync()[idx];
|
178 |
+
angle = angle.dataSync()[idx];
|
179 |
+
e = tf.sigmoid(tf.mul(eos_logit, (1. + 0.*bias))).dataSync()[0];
|
180 |
+
const rx = rand_truncated_normal(-5, 5) * stdx;
|
181 |
+
const ry = rand_truncated_normal(-5, 5) * stdy;
|
182 |
+
x = x + Math.cos(-angle) * rx - Math.sin(-angle) * ry;
|
183 |
+
y = y + Math.sin(-angle) * rx + Math.cos(-angle) * ry;
|
184 |
+
if (Math.random() < e) {
|
185 |
+
e = 1.;
|
186 |
+
} else {
|
187 |
+
e = 0.;
|
188 |
+
}
|
189 |
+
input = tf.tensor([[x, y, e]]);
|
190 |
+
return [att_location, att_context, gru0_hx, gru1_hx, input];
|
191 |
+
});
|
192 |
+
|
193 |
+
[dx,dy,eos_] = input.dataSync();
|
194 |
+
dy = -dy * 3;
|
195 |
+
dx = dx * 3;
|
196 |
+
if (eos == 0.) {
|
197 |
+
ctx.beginPath();
|
198 |
+
ctx.moveTo(cur_x, cur_y, 0, 0);
|
199 |
+
ctx.lineTo(cur_x + dx, cur_y + dy);
|
200 |
+
ctx.stroke();
|
201 |
+
}
|
202 |
+
eos = eos_;
|
203 |
+
cur_x = cur_x + dx;
|
204 |
+
cur_y = cur_y + dy;
|
205 |
+
|
206 |
+
if (att_location.dataSync()[0] < original_text.length + 2) {
|
207 |
+
setTimeout(function() {loop(my_run_id);}, 0);
|
208 |
+
}
|
209 |
+
}
|
210 |
+
}
|
211 |
+
|
212 |
+
loop(run_id);
|
213 |
+
}
|
214 |
+
|
215 |
+
window.onload = function(e) {
|
216 |
+
//Setting up canvas
|
217 |
+
canvas = document.getElementById("hw-canvas");
|
218 |
+
ctx = canvas.getContext("2d");
|
219 |
+
ctx.canvas.width = window.innerWidth;
|
220 |
+
ctx.canvas.height = window.innerHeight;
|
221 |
+
|
222 |
+
}
|