Update index.html
Browse files- index.html +677 -287
index.html
CHANGED
@@ -318,314 +318,703 @@
|
|
318 |
</div>
|
319 |
|
320 |
<script>
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
}
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
356 |
}
|
357 |
-
|
358 |
-
// ๐๏ธ Initialize biases with small positive values
|
359 |
-
const biases = Array(outputSize).fill(0.01);
|
360 |
-
this.biases.push(biases);
|
361 |
-
// ๐ Store the activation function for this layer
|
362 |
-
this.activations.push(activation);
|
363 |
}
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
return x > 0 ? scale * x : scale * alpha * (Math.exp(x) - 1); // ๐ Scaled Exponential Linear Unit
|
377 |
-
default:
|
378 |
-
throw new Error('Whoops! We don\'t know that activation function.');
|
379 |
}
|
380 |
}
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
}
|
|
|
|
|
|
|
|
|
398 |
}
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
}
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
}
|
487 |
-
|
488 |
-
}
|
489 |
-
trainError += batchError;
|
490 |
-
}
|
491 |
-
lastTrainLoss = trainError / trainSet.length;
|
492 |
-
// ๐งช Evaluate on test set if provided
|
493 |
-
if (testSet) {
|
494 |
-
let testError = 0;
|
495 |
-
for (const data of testSet) {
|
496 |
-
const prediction = this.predict(data.input);
|
497 |
-
testError += Math.abs(data.output[0] - prediction[0]);
|
498 |
}
|
499 |
-
|
500 |
}
|
501 |
-
//
|
502 |
-
|
503 |
-
|
|
|
|
|
|
|
|
|
504 |
}
|
505 |
-
|
506 |
-
|
507 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
508 |
}
|
509 |
-
//
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
515 |
}
|
|
|
516 |
}
|
517 |
-
|
518 |
-
// ๐งฎ Calculate total number of parameters
|
519 |
-
let totalParams = 0;
|
520 |
-
for (let i = 0; i < this.weights.length; i++) {
|
521 |
-
const weightLayer = this.weights[i];
|
522 |
-
const biasLayer = this.biases[i];
|
523 |
-
totalParams += weightLayer.flat().length + biasLayer.length;
|
524 |
-
}
|
525 |
-
// ๐ Create a summary of the training
|
526 |
-
const trainingSummary = {
|
527 |
-
trainLoss: lastTrainLoss,
|
528 |
-
testLoss: lastTestLoss,
|
529 |
-
parameters: totalParams,
|
530 |
-
training: {
|
531 |
-
time: end - start,
|
532 |
-
epochs,
|
533 |
-
learningRate,
|
534 |
-
batchSize
|
535 |
-
},
|
536 |
-
layers: this.layers.map(layer => ({
|
537 |
-
inputSize: layer.inputSize,
|
538 |
-
outputSize: layer.outputSize,
|
539 |
-
activation: layer.activation
|
540 |
-
}))
|
541 |
-
};
|
542 |
-
this.details = trainingSummary;
|
543 |
-
return trainingSummary;
|
544 |
}
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
const
|
555 |
-
|
556 |
-
for (let j = 0; j < weights.length; j++) {
|
557 |
-
const weight = weights[j];
|
558 |
-
let sum = biases[j];
|
559 |
-
for (let k = 0; k < layerInput.length; k++) {
|
560 |
-
sum += layerInput[k] * weight[k];
|
561 |
-
}
|
562 |
-
rawValues.push(sum);
|
563 |
-
layerOutput.push(this.activationFunction(sum, activation));
|
564 |
-
}
|
565 |
-
allRawValues.push(rawValues);
|
566 |
-
allActivations.push(layerOutput);
|
567 |
-
layerInput = layerOutput;
|
568 |
}
|
569 |
-
|
570 |
-
this.lastActivations = allActivations;
|
571 |
-
this.lastRawValues = allRawValues;
|
572 |
-
return layerInput;
|
573 |
}
|
574 |
-
//
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
details: this.details
|
582 |
-
};
|
583 |
-
const blob = new Blob([JSON.stringify(data)], {
|
584 |
-
type: 'application/json'
|
585 |
-
});
|
586 |
-
const url = URL.createObjectURL(blob);
|
587 |
-
const a = document.createElement('a');
|
588 |
-
a.href = url;
|
589 |
-
a.download = `${name}.json`;
|
590 |
-
a.click();
|
591 |
-
URL.revokeObjectURL(url);
|
592 |
}
|
593 |
-
//
|
594 |
-
|
595 |
-
|
596 |
-
const file = event.target.files[0];
|
597 |
-
if (!file) return;
|
598 |
-
const reader = new FileReader();
|
599 |
-
reader.onload = (event) => {
|
600 |
-
const text = event.target.result;
|
601 |
-
try {
|
602 |
-
const data = JSON.parse(text);
|
603 |
-
this.weights = data.weights;
|
604 |
-
this.biases = data.biases;
|
605 |
-
this.activations = data.activations;
|
606 |
-
this.layers = data.layers;
|
607 |
-
this.details = data.details;
|
608 |
-
callback();
|
609 |
-
if (this.debug === true) console.log('Model loaded successfully!');
|
610 |
-
input.removeEventListener('change', handleListener);
|
611 |
-
input.remove();
|
612 |
-
} catch (e) {
|
613 |
-
input.removeEventListener('change', handleListener);
|
614 |
-
input.remove();
|
615 |
-
if (this.debug === true) console.error('Failed to load model:', e);
|
616 |
-
}
|
617 |
-
};
|
618 |
-
reader.readAsText(file);
|
619 |
-
};
|
620 |
-
const input = document.createElement('input');
|
621 |
-
input.type = 'file';
|
622 |
-
input.accept = '.json';
|
623 |
-
input.style.opacity = '0';
|
624 |
-
document.body.append(input);
|
625 |
-
input.addEventListener('change', handleListener.bind(this));
|
626 |
-
input.click();
|
627 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
628 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
629 |
document.getElementById("loadDataBtn").onclick = () => {
|
630 |
document.getElementById('trainingData').value = `1.0, 0.0, 0.0, 0.0
|
631 |
0.7, 0.7, 0.8, 1
|
@@ -765,6 +1154,7 @@
|
|
765 |
const trainButton = document.getElementById('trainButton');
|
766 |
trainButton.disabled = true;
|
767 |
trainButton.textContent = 'training...';
|
|
|
768 |
const summary = await nn.train(trainingData, options);
|
769 |
trainButton.disabled = false;
|
770 |
trainButton.textContent = 'train';
|
|
|
318 |
</div>
|
319 |
|
320 |
<script>
|
321 |
+
/**
 * Experimental reinforcement-learning helper that perturbs the weights of a
 * host network and learns which perturbations reduce its loss.
 *
 * The host `network` must expose `weights` (layers -> rows -> numbers) and
 * `predict(input)`. To evaluate a loss it must also expose `trainingData`,
 * an array of `{ input, output }` samples; only `output[0]` is compared.
 *
 * Internally it owns several small `carbono` networks: four gate networks
 * driving LSTM-style memory cells, a Q-network and an outcome predictor.
 */
class ReinforcementModule {
  /**
   * @param {object} network - Host network (see class description).
   * @param {object} [options] - Hyper-parameters; every key is optional.
   */
  constructor(network, options = {}) {
    this.network = network;

    // `||` silently replaces 0 with the default. That is kept for sizes and
    // step widths where 0 is unusable, but 0 is a meaningful value for the
    // discount/exploration/decay factors, so those accept it explicitly.
    const allowZero = (value, fallback) => (value === 0 ? 0 : value || fallback);

    this.options = {
      memorySize: options.memorySize || 1000,
      batchSize: options.batchSize || 16,
      learningRate: options.learningRate || 0.01,
      gamma: allowZero(options.gamma, 0.9),
      epsilon: allowZero(options.epsilon, 1),
      epsilonMin: allowZero(options.epsilonMin, 0.01),
      epsilonDecay: options.epsilonDecay || 0.95,
      weightUpdateRange: options.weightUpdateRange || 0.01,
      actionSpace: options.actionSpace || 1024,
      memoryLayerSize: options.memoryLayerSize || 128,
      predictionHorizon: options.predictionHorizon || 1024,
      memoryCellDecay: allowZero(options.memoryCellDecay, 0.9)
    };

    // Memory cells, all starting at zero.
    this.memoryCells = {
      shortTerm: new Array(this.options.memoryLayerSize).fill(0),
      longTerm: new Array(this.options.memoryLayerSize).fill(0),
      cellState: new Array(this.options.memoryLayerSize).fill(0)
    };

    // One single-layer sigmoid network per gate.
    this.gates = {
      forget: this.createGateNetwork(this.options.memoryLayerSize),
      input: this.createGateNetwork(this.options.memoryLayerSize),
      output: this.createGateNetwork(this.options.memoryLayerSize),
      candidates: this.createGateNetwork(this.options.memoryLayerSize)
    };

    this.memory = [];
    this.bestWeights = this.cloneWeights(network.weights);
    this.bestLoss = Infinity;
    this.epsilon = this.options.epsilon;
    // Built last: the state vector reads bestLoss and epsilon, which must
    // already be initialised or the vector would contain `undefined`.
    this.currentState = this.getNetworkState();

    this.qNetwork = this.createQNetwork();
    this.outcomePredictor = this.createOutcomePredictor();
  }

  /**
   * Builds a single-layer sigmoid network mapping a state vector to `size`
   * gate activations.
   */
  createGateNetwork(size) {
    const gate = new carbono(false);
    gate.layer(this.getFlattenedStateSize(), size, "sigmoid");
    return gate;
  }

  /** Builds the network that scores a (state, one-hot action) pair. */
  createQNetwork() {
    const qNet = new carbono(false);
    const stateSize = this.getFlattenedStateSize();
    const actionSize = this.getActionSpaceSize();

    qNet.layer(stateSize + actionSize, 128, "selu");
    qNet.layer(128, 64, "selu");
    qNet.layer(64, 1, "selu");

    return qNet;
  }

  /**
   * Builds the network that maps state + the three memory-cell vectors to
   * `predictionHorizon` predicted outcomes.
   */
  createOutcomePredictor() {
    const predictor = new carbono(false);
    const inputSize =
      this.getFlattenedStateSize() + this.options.memoryLayerSize * 3;

    predictor.layer(inputSize, 8, "tanh");
    predictor.layer(8, 8, "selu");
    predictor.layer(8, this.options.predictionHorizon, "tanh");

    return predictor;
  }

  /** Total number of individual weights in the host network. */
  countWeights() {
    return this.network.weights.reduce(
      (total, layer) =>
        total + layer.reduce((rowTotal, row) => rowTotal + row.length, 0),
      0
    );
  }

  /** Length of the state vector: every weight plus three scalar features. */
  getFlattenedStateSize() {
    return this.countWeights() + 3;
  }

  /** Number of discrete actions: one per (weight, update level) pair. */
  getActionSpaceSize() {
    return this.countWeights() * this.options.actionSpace;
  }

  /** True when the host network exposes a non-empty `trainingData` array. */
  hasTrainingData() {
    const samples = this.network.trainingData;
    return Array.isArray(samples) && samples.length > 0;
  }

  /**
   * Mean absolute error of the host network over its training data, using
   * only the first output of each sample. Callers must make sure training
   * data exists (see `hasTrainingData`).
   */
  measureLoss() {
    const samples = this.network.trainingData;
    let totalLoss = 0;
    for (const sample of samples) {
      const prediction = this.network.predict(sample.input);
      totalLoss += Math.abs(prediction[0] - sample.output[0]);
    }
    return totalLoss / samples.length;
  }

  /**
   * Returns the state vector: all weights followed by best loss, current
   * loss and epsilon.
   *
   * The three scalar features are forced to finite numbers (0 when unknown).
   * `bestLoss` starts as Infinity and the loss is unknown until training data
   * exists; feeding either into a network would turn every output into NaN.
   */
  getNetworkState() {
    const finiteOrZero = (value) => (Number.isFinite(value) ? value : 0);
    const flatWeights = this.network.weights.flat(2);
    // Computed synchronously: the async getCurrentLoss() returns a Promise,
    // which must never end up inside the numeric state vector.
    const currentLoss = this.hasTrainingData() ? this.measureLoss() : 0;
    return [
      ...flatWeights,
      finiteOrZero(this.bestLoss),
      finiteOrZero(currentLoss),
      finiteOrZero(this.epsilon)
    ];
  }

  /**
   * Resolves to the current mean absolute error of the host network.
   * @throws {TypeError} (as a rejection) when the host has no training data.
   */
  async getCurrentLoss() {
    if (!this.hasTrainingData()) {
      throw new TypeError(
        "ReinforcementModule needs network.trainingData (a non-empty array of { input, output }) to evaluate the loss."
      );
    }
    return this.measureLoss();
  }

  /** Advances the memory cells one step using the four gate networks. */
  async updateMemoryCells(state) {
    const forgetGate = this.gates.forget.predict(state);
    const inputGate = this.gates.input.predict(state);
    const outputGate = this.gates.output.predict(state);
    const candidates = this.gates.candidates.predict(state);

    for (let i = 0; i < this.options.memoryLayerSize; i++) {
      this.memoryCells.cellState[i] *= forgetGate[i];
      this.memoryCells.cellState[i] += inputGate[i] * candidates[i];
      this.memoryCells.shortTerm[i] =
        Math.tanh(this.memoryCells.cellState[i]) * outputGate[i];
      // Long-term memory is an exponential moving average of short-term.
      this.memoryCells.longTerm[i] =
        this.memoryCells.longTerm[i] * this.options.memoryCellDecay +
        this.memoryCells.shortTerm[i] * (1 - this.options.memoryCellDecay);
    }
  }

  /** Runs the outcome predictor on `state` plus the current memory cells. */
  async predictOutcomes(state) {
    const input = [
      ...state,
      ...this.memoryCells.shortTerm,
      ...this.memoryCells.longTerm,
      ...this.memoryCells.cellState
    ];
    return this.outcomePredictor.predict(input);
  }

  /** One-hot encodes an action index over the whole action space. */
  encodeAction(action) {
    const encoded = new Array(this.getActionSpaceSize()).fill(0);
    encoded[action] = 1;
    return encoded;
  }

  /** Resolves to the Q-network's score for taking `action` in `state`. */
  async predictQValue(state, action) {
    const encoded = this.encodeAction(action);
    const input = [...state, ...encoded];
    const qValue = this.qNetwork.predict(input);
    return qValue[0];
  }

  /**
   * Returns a copy of `state` with the weight updates of `action` added to
   * its leading (weight) entries; the host network is not touched.
   */
  simulateAction(state, action) {
    const simState = [...state];
    const updates = this.actionToWeightUpdates(action);
    let stateIndex = 0;

    for (const layer of updates) {
      for (const row of layer) {
        for (const update of row) {
          simState[stateIndex] += update;
          stateIndex++;
        }
      }
    }

    return simState;
  }

  /**
   * Picks an action epsilon-greedily: a uniformly random action with
   * probability `epsilon`, otherwise the action whose predicted outcomes
   * have the highest gamma-discounted sum.
   */
  async selectAction() {
    // Hoisted: computing the size walks every weight of the host network.
    const actionCount = this.getActionSpaceSize();

    if (Math.random() < this.epsilon) {
      return Math.floor(Math.random() * actionCount);
    }

    const state = this.getNetworkState();
    await this.updateMemoryCells(state);

    let bestAction = 0;
    let bestOutcome = -Infinity;

    for (let action = 0; action < actionCount; action++) {
      const simState = this.simulateAction(state, action);
      const outcomes = await this.predictOutcomes(simState);

      const expectedValue = outcomes.reduce((sum, val, i) => {
        return sum + val * Math.pow(this.options.gamma, i);
      }, 0);

      if (expectedValue > bestOutcome) {
        bestOutcome = expectedValue;
        bestAction = action;
      }
    }

    return bestAction;
  }

  /**
   * Decodes an action index into per-weight deltas shaped like
   * `network.weights`.
   *
   * The action space holds `actionSpace` entries per weight, so an index
   * selects exactly one weight (`floor(action / actionSpace)`) and one level
   * (`action % actionSpace`). Levels map linearly onto
   * [-weightUpdateRange, +weightUpdateRange]; every other weight gets 0.
   * Decoding the index as one base-`actionSpace` digit per weight does not
   * fit an action space of that size: all digits past the second are always
   * 0, which pushed almost every weight by -weightUpdateRange on each action.
   */
  actionToWeightUpdates(action) {
    const { actionSpace, weightUpdateRange } = this.options;
    const targetIndex = Math.floor(action / actionSpace);
    const level = action % actionSpace;
    // Guards the 0/0 that a single-level action space would produce.
    const levelSpan = Math.max(1, actionSpace - 1);
    const delta = ((level / levelSpan) * 2 - 1) * weightUpdateRange;

    let weightIndex = 0;
    return this.network.weights.map((layer) =>
      layer.map((row) =>
        row.map(() => {
          const update = weightIndex === targetIndex ? delta : 0;
          weightIndex++;
          return update;
        })
      )
    );
  }

  /** Adds the weight updates of `action` to the host network in place. */
  async applyAction(action) {
    const updates = this.actionToWeightUpdates(action);
    for (let i = 0; i < this.network.weights.length; i++) {
      for (let j = 0; j < this.network.weights[i].length; j++) {
        for (let k = 0; k < this.network.weights[i][j].length; k++) {
          this.network.weights[i][j][k] += updates[i][j][k];
        }
      }
    }
  }

  /**
   * Reward = loss improvement, plus a bonus of 1 when the new loss beats the
   * best loss seen so far.
   */
  calculateReward(oldLoss, newLoss) {
    const improvement = oldLoss - newLoss;
    const bestReward = newLoss < this.bestLoss ? 1.0 : 0.0;
    return improvement + bestReward;
  }

  /**
   * Collects `steps` loss readings used as training targets for the outcome
   * predictor.
   *
   * NOTE(review): the selected actions are only applied to a local copy of
   * the state, never to the network, so every reading is the same loss.
   * Confirm whether the rollout was meant to apply actions for real.
   */
  async getActualOutcomes(state, steps) {
    const outcomes = [];
    let currentState = state;

    for (let i = 0; i < steps; i++) {
      const loss = await this.getCurrentLoss();
      outcomes.push(loss);
      const action = await this.selectAction();
      currentState = this.simulateAction(currentState, action);
    }

    return outcomes;
  }

  /** Fits the outcome predictor on the outcomes observed after one step. */
  async trainOutcomePredictor(experience) {
    const { state, nextState } = experience;
    const actualOutcomes = await this.getActualOutcomes(
      nextState,
      this.options.predictionHorizon
    );

    const input = [
      ...state,
      ...this.memoryCells.shortTerm,
      ...this.memoryCells.longTerm,
      ...this.memoryCells.cellState
    ];

    await this.outcomePredictor.train(
      [
        {
          input: input,
          output: actualOutcomes
        }
      ],
      {
        epochs: 1000,
        learningRate: this.options.learningRate
      }
    );
  }

  /**
   * Q-learning update for every experience in `batch`: the target is
   * `reward + gamma * max_a Q(nextState, a)`.
   */
  async trainQNetwork(batch) {
    const actionCount = this.getActionSpaceSize();

    for (const experience of batch) {
      const { state, action, reward, nextState } = experience;

      let maxNextQ = -Infinity;
      for (let a = 0; a < actionCount; a++) {
        const nextQ = await this.predictQValue(nextState, a);
        maxNextQ = Math.max(maxNextQ, nextQ);
      }

      const targetQ = reward + this.options.gamma * maxNextQ;
      const input = [...state, ...this.encodeAction(action)];

      await this.qNetwork.train(
        [
          {
            input: input,
            output: [targetQ]
          }
        ],
        {
          epochs: 100,
          learningRate: this.options.learningRate
        }
      );
    }
  }

  /**
   * Performs one reinforcement step: select and apply an action, record the
   * experience, train the helper networks, track the best weights and decay
   * epsilon.
   *
   * @param {number} currentLoss - Loss of the host network before this step.
   * @returns {Promise<{loss: number, bestLoss: number, epsilon: number}>}
   * @throws {TypeError} (as a rejection) when the host has no training data;
   *   this is checked before any weight is modified.
   */
  async update(currentLoss) {
    if (!this.hasTrainingData()) {
      throw new TypeError(
        "ReinforcementModule.update() needs network.trainingData (a non-empty array of { input, output })."
      );
    }

    const state = this.getNetworkState();
    const action = await this.selectAction();
    await this.applyAction(action);
    const nextState = this.getNetworkState();
    const newLoss = await this.getCurrentLoss();
    const reward = this.calculateReward(currentLoss, newLoss);

    const experience = {
      state,
      action,
      reward,
      nextState
    };

    this.memory.push(experience);
    await this.trainOutcomePredictor(experience);

    // Replay memory is a bounded FIFO.
    if (this.memory.length > this.options.memorySize) {
      this.memory.shift();
    }

    if (this.memory.length >= this.options.batchSize) {
      const batch = [];
      for (let i = 0; i < this.options.batchSize; i++) {
        const index = Math.floor(Math.random() * this.memory.length);
        batch.push(this.memory[index]);
      }
      await this.trainQNetwork(batch);
    }

    if (newLoss < this.bestLoss) {
      this.bestLoss = newLoss;
      this.bestWeights = this.cloneWeights(this.network.weights);
    }

    this.epsilon = Math.max(
      this.options.epsilonMin,
      this.epsilon * this.options.epsilonDecay
    );

    return {
      loss: newLoss,
      bestLoss: this.bestLoss,
      epsilon: this.epsilon
    };
  }

  /** Deep-copies a layers -> rows -> numbers weight structure. */
  cloneWeights(weights) {
    return weights.map((layer) => layer.map((row) => [...row]));
  }
}
|
675 |
+
// ๐ง carbono: A Fun and Friendly Neural Network Class ๐ง
|
676 |
+
// This micro-library wraps everything you need to have
|
677 |
+
// This is the simplest yet functional feedforward mlp in js
|
678 |
+
/**
 * carbono: a small feed-forward multilayer perceptron for the browser.
 *
 * Layers are added with `layer()`, trained with `train()` (per-sample
 * gradient steps) and evaluated with `predict()`. `save()`/`load()` move the
 * model to and from a JSON file through the DOM.
 */
class carbono {
  /**
   * @param {boolean} [debug=true] - When true, progress and status messages
   *   are written to the console.
   */
  constructor(debug = true) {
    this.layers = []; // { inputSize, outputSize, activation } per layer
    this.weights = []; // weights[layer][outputNeuron][inputIndex]
    this.biases = []; // biases[layer][outputNeuron]
    this.activations = []; // activation name per layer
    this.details = {}; // summary of the last training run
    this.debug = debug;
  }

  /**
   * Attaches a ReinforcementModule to this network; `train()` calls its
   * `update()` once per epoch.
   * @param {object} [options] - Passed through to ReinforcementModule.
   * @returns {ReinforcementModule} The attached module (also in `this.rl`).
   */
  play(options = {}) {
    console.log("Reinforcement Learning Activated");
    this.rl = new ReinforcementModule(this, options);
    return this.rl;
  }

  /**
   * Appends a fully connected layer.
   * @param {number} inputSize - Must equal the previous layer's outputSize.
   * @param {number} outputSize - Number of neurons in the layer.
   * @param {string} [activation="tanh"] - "tanh", "sigmoid", "relu" or "selu".
   * @throws {Error} When inputSize does not match the previous layer.
   */
  layer(inputSize, outputSize, activation = "tanh") {
    // Validate before recording anything, so a rejected layer does not
    // leave a stray entry in `this.layers`.
    if (this.weights.length > 0) {
      const lastLayerOutputSize = this.layers[this.layers.length - 1]
        .outputSize;
      if (inputSize !== lastLayerOutputSize) {
        throw new Error(
          "Oops! The input size of the new layer must match the output size of the previous layer."
        );
      }
    }

    this.layers.push({
      inputSize,
      outputSize,
      activation
    });

    // Xavier/Glorot uniform initialisation.
    const weights = [];
    for (let i = 0; i < outputSize; i++) {
      const row = [];
      for (let j = 0; j < inputSize; j++) {
        row.push(
          (Math.random() - 0.5) * 2 * Math.sqrt(6 / (inputSize + outputSize))
        );
      }
      weights.push(row);
    }
    this.weights.push(weights);

    // Biases start at a small positive value.
    const biases = Array(outputSize).fill(0.01);
    this.biases.push(biases);

    this.activations.push(activation);
  }

  /**
   * Applies the named activation function to `x`.
   * @throws {Error} For an unknown activation name.
   */
  activationFunction(x, activation) {
    switch (activation) {
      case "tanh":
        return Math.tanh(x);
      case "sigmoid":
        return 1 / (1 + Math.exp(-x));
      case "relu":
        return Math.max(0, x);
      case "selu": {
        const alpha = 1.67326;
        const scale = 1.0507;
        return x > 0 ? scale * x : scale * alpha * (Math.exp(x) - 1);
      }
      default:
        throw new Error("Whoops! We don't know that activation function.");
    }
  }

  /**
   * Derivative of the named activation function evaluated at `x`.
   * @throws {Error} For an unknown activation name.
   */
  activationDerivative(x, activation) {
    switch (activation) {
      case "tanh":
        return 1 - Math.pow(Math.tanh(x), 2);
      case "sigmoid": {
        const sigmoid = 1 / (1 + Math.exp(-x));
        return sigmoid * (1 - sigmoid);
      }
      case "relu":
        return x > 0 ? 1 : 0;
      case "selu": {
        const alpha = 1.67326;
        const scale = 1.0507;
        return x > 0 ? scale : scale * alpha * Math.exp(x);
      }
      default:
        throw new Error(
          "Oops! We don't know the derivative of that activation function."
        );
    }
  }

  /**
   * Trains the network; weights are adjusted after every sample.
   *
   * @param {Array<{input: number[], output: number[]}>} trainSet
   * @param {object} [options]
   * @param {number} [options.epochs=200] - Passes over the training set.
   * @param {number} [options.learningRate=0.212]
   * @param {number} [options.batchSize=16] - Values below 1 are replaced by 2.
   * @param {number} [options.printEveryEpochs=100] - Log interval (debug only).
   * @param {number} [options.earlyStopThreshold=1e-6] - Stop once the train
   *   loss drops below this value.
   * @param {Array} [options.testSet=null] - Optional evaluation set.
   * @param {Function} [options.callback=null] - Awaited after each epoch with
   *   (epochNumber, trainLoss, testLoss).
   * @returns {Promise<object>} Training summary, also stored in `details`.
   */
  async train(trainSet, options = {}) {
    const {
      epochs = 200,
      learningRate = 0.212,
      batchSize: requestedBatchSize = 16,
      printEveryEpochs = 100,
      earlyStopThreshold = 1e-6,
      testSet = null,
      callback = null
    } = options;
    // Destructured bindings are `const`, so the fallback needs its own
    // binding; reassigning the destructured name throws a TypeError.
    const batchSize = requestedBatchSize < 1 ? 2 : requestedBatchSize;
    const start = Date.now();

    // Exposed so an attached ReinforcementModule can evaluate the loss
    // (it reads `network.trainingData`).
    this.trainingData = trainSet;

    // Create a default two-layer network when none was configured.
    if (this.layers.length === 0) {
      const numInputs = trainSet[0].input.length;
      this.layer(numInputs, numInputs, "tanh");
      this.layer(numInputs, 1, "tanh");
    }

    let lastTrainLoss = 0;
    let lastTestLoss = null;

    for (let epoch = 0; epoch < epochs; epoch++) {
      let trainError = 0;

      for (let b = 0; b < trainSet.length; b += batchSize) {
        const batch = trainSet.slice(b, b + batchSize);
        let batchError = 0;

        for (const data of batch) {
          // Forward pass; layerInputs[i] is the input of layer i and
          // layerInputs[i + 1] its activated output.
          const layerInputs = [data.input];
          for (let i = 0; i < this.weights.length; i++) {
            const inputs = layerInputs[i];
            const weights = this.weights[i];
            const biases = this.biases[i];
            const activation = this.activations[i];
            const outputs = [];
            for (let j = 0; j < weights.length; j++) {
              const weight = weights[j];
              let sum = biases[j];
              for (let k = 0; k < inputs.length; k++) {
                sum += inputs[k] * weight[k];
              }
              outputs.push(this.activationFunction(sum, activation));
            }
            layerInputs.push(outputs);
          }

          // Backward pass: output error is target minus prediction.
          const networkOutputs = layerInputs[layerInputs.length - 1];
          const outputErrors = [];
          for (let i = 0; i < networkOutputs.length; i++) {
            outputErrors.push(data.output[i] - networkOutputs[i]);
          }

          const layerErrors = [outputErrors];
          for (let i = this.weights.length - 2; i >= 0; i--) {
            const nextLayerWeights = this.weights[i + 1];
            const nextLayerErrors = layerErrors[0];
            const currentLayerOutputs = layerInputs[i + 1];
            const currentActivation = this.activations[i];
            const errors = [];
            for (let j = 0; j < this.layers[i].outputSize; j++) {
              let error = 0;
              for (let k = 0; k < this.layers[i + 1].outputSize; k++) {
                error += nextLayerErrors[k] * nextLayerWeights[k][j];
              }
              // NOTE(review): the derivative is evaluated on the layer's
              // activated output rather than its pre-activation sum, and the
              // output layer applies no derivative at all. Kept unchanged so
              // training behaves as before; confirm whether this is intended.
              errors.push(
                error *
                  this.activationDerivative(
                    currentLayerOutputs[j],
                    currentActivation
                  )
              );
            }
            layerErrors.unshift(errors);
          }

          // Apply the weight and bias updates for this sample.
          for (let i = 0; i < this.weights.length; i++) {
            const inputs = layerInputs[i];
            const errors = layerErrors[i];
            const weights = this.weights[i];
            const biases = this.biases[i];
            for (let j = 0; j < weights.length; j++) {
              const weight = weights[j];
              for (let k = 0; k < inputs.length; k++) {
                weight[k] += learningRate * errors[j] * inputs[k];
              }
              biases[j] += learningRate * errors[j];
            }
          }

          // Only the first output contributes to the reported loss.
          batchError += Math.abs(outputErrors[0]);
        }
        trainError += batchError;
      }
      lastTrainLoss = trainError / trainSet.length;

      // Reinforcement step, if a module was attached via play(). It is
      // awaited so it cannot modify the weights concurrently with the next
      // epoch; a failing step is reported and training carries on.
      if (this.rl) {
        try {
          await this.rl.update(lastTrainLoss);
        } catch (error) {
          console.error("Reinforcement update failed:", error);
        }
      }

      // Evaluate on the test set if provided.
      if (testSet) {
        let testError = 0;
        for (const data of testSet) {
          const prediction = this.predict(data.input);
          testError += Math.abs(data.output[0] - prediction[0]);
        }
        lastTestLoss = testError / testSet.length;
      }

      if ((epoch + 1) % printEveryEpochs === 0 && this.debug === true) {
        console.log(
          `Epoch ${epoch + 1}, Train Loss: ${lastTrainLoss.toFixed(6)}${
            testSet ? `, Test Loss: ${lastTestLoss.toFixed(6)}` : ""
          }`
        );
      }

      if (callback) {
        await callback(epoch + 1, lastTrainLoss, lastTestLoss);
      }

      // Yield to the event loop so the page stays responsive.
      await new Promise((resolve) => setTimeout(resolve, 0));

      if (lastTrainLoss < earlyStopThreshold) {
        if (this.debug === true) {
          console.log(
            `We stopped at epoch ${
              epoch + 1
            } with train loss: ${lastTrainLoss.toFixed(6)}${
              testSet ? ` and test loss: ${lastTestLoss.toFixed(6)}` : ""
            }`
          );
        }
        break;
      }
    }
    const end = Date.now();

    // Total number of trainable parameters (weights plus biases).
    let totalParams = 0;
    for (let i = 0; i < this.weights.length; i++) {
      const weightLayer = this.weights[i];
      const biasLayer = this.biases[i];
      totalParams += weightLayer.flat().length + biasLayer.length;
    }

    const trainingSummary = {
      trainLoss: lastTrainLoss,
      testLoss: lastTestLoss,
      parameters: totalParams,
      training: {
        time: end - start,
        epochs,
        learningRate,
        batchSize
      },
      layers: this.layers.map((layer) => ({
        inputSize: layer.inputSize,
        outputSize: layer.outputSize,
        activation: layer.activation
      }))
    };
    this.details = trainingSummary;
    return trainingSummary;
  }

  /**
   * Runs a forward pass.
   * Also stores every layer's activations in `lastActivations` and the
   * pre-activation sums in `lastRawValues`.
   * @param {number[]} input
   * @returns {number[]} Output of the last layer.
   */
  predict(input) {
    let layerInput = input;
    const allActivations = [input];
    const allRawValues = [];
    for (let i = 0; i < this.weights.length; i++) {
      const weights = this.weights[i];
      const biases = this.biases[i];
      const activation = this.activations[i];
      const layerOutput = [];
      const rawValues = [];
      for (let j = 0; j < weights.length; j++) {
        const weight = weights[j];
        let sum = biases[j];
        for (let k = 0; k < layerInput.length; k++) {
          sum += layerInput[k] * weight[k];
        }
        rawValues.push(sum);
        layerOutput.push(this.activationFunction(sum, activation));
      }
      allRawValues.push(rawValues);
      allActivations.push(layerOutput);
      layerInput = layerOutput;
    }
    this.lastActivations = allActivations;
    this.lastRawValues = allRawValues;
    return layerInput;
  }

  /**
   * Downloads the model as `<name>.json` through a temporary link.
   * @param {string} [name="model"] - File name without extension.
   */
  save(name = "model") {
    const data = {
      weights: this.weights,
      biases: this.biases,
      activations: this.activations,
      layers: this.layers,
      details: this.details
    };
    const blob = new Blob([JSON.stringify(data)], {
      type: "application/json"
    });
    const url = URL.createObjectURL(blob);
    const a = document.createElement("a");
    a.href = url;
    a.download = `${name}.json`;
    a.click();
    URL.revokeObjectURL(url);
  }

  /**
   * Opens a file picker and replaces this model with the chosen JSON file.
   * @param {Function} [callback] - Called once the model has been loaded.
   */
  load(callback) {
    const input = document.createElement("input");

    // Removes the temporary file input and its listener again.
    const cleanUp = () => {
      input.removeEventListener("change", handleListener);
      input.remove();
    };

    const handleListener = (event) => {
      const file = event.target.files[0];
      if (!file) return;
      const reader = new FileReader();
      reader.onload = (loadEvent) => {
        const text = loadEvent.target.result;
        try {
          const data = JSON.parse(text);
          this.weights = data.weights;
          this.biases = data.biases;
          this.activations = data.activations;
          this.layers = data.layers;
          this.details = data.details;
          if (typeof callback === "function") callback();
          if (this.debug === true) console.log("Model loaded successfully!");
          cleanUp();
        } catch (e) {
          cleanUp();
          if (this.debug === true) console.error("Failed to load model:", e);
        }
      };
      reader.readAsText(file);
    };

    input.type = "file";
    input.accept = ".json";
    input.style.opacity = "0";
    document.body.append(input);
    // The arrow function already captures `this`. Registering the very same
    // reference is what allows removeEventListener to find it later.
    input.addEventListener("change", handleListener);
    input.click();
  }
}
|
1018 |
document.getElementById("loadDataBtn").onclick = () => {
|
1019 |
document.getElementById('trainingData').value = `1.0, 0.0, 0.0, 0.0
|
1020 |
0.7, 0.7, 0.8, 1
|
|
|
1154 |
const trainButton = document.getElementById('trainButton');
|
1155 |
trainButton.disabled = true;
|
1156 |
trainButton.textContent = 'training...';
|
1157 |
+
nn.play()
|
1158 |
const summary = await nn.train(trainingData, options);
|
1159 |
trainButton.disabled = false;
|
1160 |
trainButton.textContent = 'train';
|