|
"""Neural network implementation for BackpropNEAT.""" |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
from typing import Dict, List, Optional, Tuple, Union |
|
from .genome import Genome |
|
import copy |
|
import random |
|
|
|
class Network:
    """Neural network for NEAT implementation.

    Implements a strictly feed-forward network following original NEAT principles:

    1. Start minimal - direct input-output connections only
    2. Complexify gradually through structural mutations
    3. Protect innovation through speciation
    4. No recurrent connections (as per requirements)

    Node-id layout used throughout this class:

    * ``0-11``  input nodes. ``forward`` fills only ``0-7`` from data; ``8-11``
      are always 0.
    * ``12``    bias node, held at a constant activation of 1.0 in ``forward``.
    * ``13-15`` output nodes (sigmoid).
    * any other id comes from the genome (hidden nodes).
    """

    _NUM_INPUTS = 12           # input node ids are 0 .. _NUM_INPUTS - 1
    _NUM_DATA_INPUTS = 8       # how many input nodes forward() fills from data
    _BIAS_ID = 12
    _OUTPUT_IDS = (13, 14, 15)

    # Starting topology used when the genome has no connection genes. One entry
    # per output (in _OUTPUT_IDS order): the uniform range for its bias weight,
    # then (input id, uniform range) for each direct input connection.
    _SEED_TOPOLOGY = (
        ((0.0, 1.0), ((0, (0.5, 1.5)), (2, (0.5, 1.5)))),
        ((0.0, 1.0), ((0, (-1.5, -0.5)), (2, (-1.5, -0.5)))),
        ((-0.5, 0.5), ((1, (-1.5, -0.5)), (3, (-1.0, 0.0)))),
    )

    def __init__(self, genome: Genome):
        """Initialize network from genome.

        If the genome does not declare 12 inputs / 3 outputs a warning is
        printed and the genome's ``input_size``/``output_size`` are overwritten
        in place.
        """
        self.genome = genome
        self.input_size = self._NUM_INPUTS
        self.output_size = len(self._OUTPUT_IDS)
        self.node_genes = {}
        self.connection_genes = []
        self.node_evals = {}
        self._load_genome()

    def _load_genome(self) -> None:
        """(Re)build nodes, connections and the evaluation plan from ``self.genome``."""
        genome = self.genome
        if genome.input_size != 12 or genome.output_size != 3:
            print(f"Warning: Genome size mismatch. Expected 12 inputs, 3 outputs. Got {genome.input_size} inputs, {genome.output_size} outputs")
            genome.input_size = 12
            genome.output_size = 3

        self.node_genes = {}
        self.connection_genes = []
        self._add_seed_topology()
        self._add_genome_nodes()
        # A genome that carries connections replaces the seed connections wholesale.
        if genome.connection_genes:
            self.connection_genes = self._copy_genome_connections()
        self._ensure_output_connections()

        self._build_feed_forward_order()
        self._verify_outputs()

    def _add_seed_topology(self) -> None:
        """Create input/bias/output nodes plus the hand-picked seed connections."""
        for input_id in range(self._NUM_INPUTS):
            self.node_genes[input_id] = NodeGene(input_id, 'input', 'linear')
        self.node_genes[self._BIAS_ID] = NodeGene(self._BIAS_ID, 'bias', 'linear')

        for output_id, (bias_range, input_links) in zip(self._OUTPUT_IDS, self._SEED_TOPOLOGY):
            self.node_genes[output_id] = NodeGene(output_id, 'output', 'sigmoid')
            self.connection_genes.append(
                ConnectionGene(self._BIAS_ID, output_id, random.uniform(*bias_range), True)
            )
            for input_id, (low, high) in input_links:
                self.connection_genes.append(
                    ConnectionGene(input_id, output_id, random.uniform(low, high), True)
                )

    def _add_genome_nodes(self) -> None:
        """Add genome nodes whose ids are not already taken by the fixed layout."""
        for node_id, node in self.genome.node_genes.items():
            if node_id in self.node_genes:
                continue
            # Genome nodes may expose ``node_type``; this module's NodeGene (which
            # to_genome() stores in the genome) exposes ``type``. Accept both.
            node_type = getattr(node, 'node_type', getattr(node, 'type', 'hidden'))
            self.node_genes[node_id] = NodeGene(node_id, node_type, node.activation)

    def _copy_genome_connections(self) -> List['ConnectionGene']:
        """Copy genome connections, skipping any that reference unknown nodes."""
        copied = []
        for conn in self.genome.connection_genes:
            if conn.source not in self.node_genes or conn.target not in self.node_genes:
                print(f"Warning: Connection {conn.source}->{conn.target} references missing nodes")
                continue
            copied.append(ConnectionGene(conn.source, conn.target, conn.weight, conn.enabled))
        return copied

    def _ensure_output_connections(self) -> None:
        """Give every output without an enabled incoming edge a bias + random-input edge."""
        for output_id in self._OUTPUT_IDS:
            if any(conn.enabled and conn.target == output_id for conn in self.connection_genes):
                continue
            print(f"Adding missing connections for output {output_id}")
            self.connection_genes.append(
                ConnectionGene(self._BIAS_ID, output_id, random.uniform(-1.0, 1.0), True)
            )
            input_id = random.randint(0, self._NUM_INPUTS - 1)
            self.connection_genes.append(
                ConnectionGene(input_id, output_id, random.uniform(-1.0, 1.0), True)
            )

    def _add_default_output_eval(self, output_id: int, message: str) -> None:
        """Print ``message`` and give ``output_id`` a bias-only sigmoid evaluation.

        Also adds a bias connection with weight 1.0 if the output has no enabled
        incoming connection at all, so connection genes match the evaluation.
        """
        print(message)
        self.node_evals[output_id] = {
            'inputs': [self._BIAS_ID],
            'weights': [1.0],
            'activation': 'sigmoid',
        }
        if not any(c.target == output_id and c.enabled for c in self.connection_genes):
            self.connection_genes.append(ConnectionGene(self._BIAS_ID, output_id, 1.0, True))

    def _verify_outputs(self):
        """Verify all outputs have valid connections and evaluations."""
        for output_id in self._OUTPUT_IDS:
            if output_id not in self.node_evals:
                self._add_default_output_eval(
                    output_id, f"Adding missing evaluation for output {output_id}"
                )

    def _create_minimal_connections(self):
        """Append a bias connection and one random-input connection to every output.

        Not called from within this class.
        """
        for i in range(self.output_size):
            output_id = self._BIAS_ID + 1 + i
            self.connection_genes.append(ConnectionGene(
                self._BIAS_ID, output_id,
                random.uniform(-1.0, 1.0),
                True,
            ))
            input_id = random.randint(0, self.input_size - 1)
            self.connection_genes.append(ConnectionGene(
                input_id, output_id,
                random.uniform(-1.0, 1.0),
                True,
            ))

    def _build_feed_forward_order(self):
        """Rebuild ``self.node_evals`` in a valid feed-forward (topological) order.

        A node is scheduled once every source of its enabled incoming connections
        has been scheduled; input and bias nodes count as already available.
        ``node_evals`` is rebuilt from scratch so entries from an earlier
        topology cannot survive, and its insertion order is the order in which
        ``forward`` evaluates nodes. Nodes without incoming connections are
        scheduled but get no entry (their activation stays 0). Nodes on a cycle,
        or fed by one, are never scheduled; any output left without an entry
        falls back to a bias-only sigmoid.
        """
        self.node_evals = {}
        try:
            incoming: Dict[int, List[Tuple[int, float]]] = {}
            for conn in self.connection_genes:
                if conn.enabled:
                    incoming.setdefault(conn.target, []).append((conn.source, conn.weight))

            evaluated = set(range(self._NUM_INPUTS)) | {self._BIAS_ID}
            while True:
                # Ready = every dependency (incoming source) is already evaluated.
                ready = sorted(
                    node_id for node_id in self.node_genes
                    if node_id not in evaluated
                    and all(source in evaluated for source, _ in incoming.get(node_id, ()))
                )
                if not ready:
                    break
                for node_id in ready:
                    edges = incoming.get(node_id)
                    if edges:
                        self.node_evals[node_id] = {
                            'inputs': [source for source, _ in edges],
                            'weights': [weight for _, weight in edges],
                            'activation': self.node_genes[node_id].activation,
                        }
                    evaluated.add(node_id)

            for output_id in self._OUTPUT_IDS:
                if output_id not in self.node_evals:
                    self._add_default_output_eval(
                        output_id, f"Adding default evaluation for output {output_id}"
                    )

        except Exception as e:
            # Best effort: fall back to bias-only outputs rather than failing construction.
            print(f"Error in feed-forward build: {e}")
            self.node_evals = {}
            for output_id in self._OUTPUT_IDS:
                self.node_evals[output_id] = {
                    'inputs': [self._BIAS_ID],
                    'weights': [1.0],
                    'activation': 'sigmoid',
                }

    @staticmethod
    def _activate(name: str, x: jnp.ndarray) -> jnp.ndarray:
        """Apply the named activation; 'linear' and unknown names are the identity."""
        if name == 'tanh':
            return jnp.tanh(x)
        if name == 'sigmoid':
            return jax.nn.sigmoid(x)
        if name == 'relu':
            return jax.nn.relu(x)
        return x

    def forward(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the network.

        Args:
            inputs: ``(n_features,)`` or ``(batch, n_features)``. Only the first
                8 features are used; they feed input nodes 0-7 (missing
                features are treated as 0).

        Returns:
            Activations of output nodes 13-15: shape ``(3,)`` for 1-D input,
            ``(batch, 3)`` for 2-D input. On an unexpected error the error is
            printed and a zero array of the corresponding shape is returned.
        """
        n_outputs = len(self._OUTPUT_IDS)
        raw_shape = getattr(inputs, 'shape', ())
        try:
            x = jnp.asarray(inputs)
            single = x.ndim == 1
            if single:
                x = x.reshape(1, -1)
            # Slice the feature axis (not the batch axis) down to the data inputs.
            x = x[:, :self._NUM_DATA_INPUTS]
            batch_size, n_features = x.shape

            max_node_id = max(node.id for node in self.node_genes.values())
            activations = jnp.zeros((batch_size, max_node_id + 1))
            # Input nodes beyond the supplied features (incl. 8-11) stay 0.
            activations = activations.at[:, :n_features].set(x)
            # The bias node must be 1.0, otherwise every bias weight is a no-op.
            activations = activations.at[:, self._BIAS_ID].set(1.0)

            # node_evals is in topological order, so every source is computed first.
            for node_id, eval_info in self.node_evals.items():
                if node_id <= self._BIAS_ID:
                    continue  # inputs and bias are fixed above
                try:
                    total = jnp.zeros(batch_size)
                    for source, weight in zip(eval_info['inputs'], eval_info['weights']):
                        total = total + activations[:, source] * weight
                    act = self._activate(eval_info['activation'], total)
                    # NOTE(review): nodes with id >= 20 are hard-thresholded to {0, 1};
                    # this step has zero gradient — confirm it is intended.
                    if node_id >= 20:
                        act = jnp.where(act > 0.75, 1.0, 0.0)
                    activations = activations.at[:, node_id].set(act)
                except Exception as e:
                    print(f"Error at node {node_id}: {e}")

            # Select outputs by id: the last columns may belong to hidden nodes.
            output = activations[:, jnp.asarray(self._OUTPUT_IDS)]
            if single:
                output = output.reshape(-1)
            return output

        except Exception as e:
            print(f"Error in forward pass: {e}")
            if len(raw_shape) > 1:
                return jnp.zeros((raw_shape[0], n_outputs))
            return jnp.zeros(n_outputs)

    def predict(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Make a prediction for the given inputs.

        Args:
            inputs: Input array of shape (input_size,) or (batch_size, input_size)

        Returns:
            Predictions of shape (3,) for single input or (batch_size, 3) for batch
            (narrower outputs are zero-padded on the right).
        """
        outputs = self.forward(inputs)

        if len(outputs.shape) == 1:
            if outputs.shape[0] != 3:
                print(f"Adjusting output shape from {outputs.shape} to (3,)")
                return jnp.pad(outputs, (0, max(0, 3 - outputs.shape[0])))
            return outputs

        if outputs.shape[1] != 3:
            print(f"Adjusting output shape from {outputs.shape} to (batch_size, 3)")
            return jnp.pad(outputs, ((0, 0), (0, max(0, 3 - outputs.shape[1]))))
        return outputs

    def clone(self) -> 'Network':
        """Create a copy of this network with a cloned genome."""
        return Network(self.genome.clone())

    def mutate(self, config: Dict):
        """Mutate the network's genome and rebuild the network from it.

        Nodes, connections and the evaluation plan are re-derived from the
        mutated genome (the same way ``__init__`` and ``clone`` derive them) so
        that mutations are reflected in ``forward``. Weights held only on this
        network's ``connection_genes`` and not in the genome are discarded.
        """
        self.genome.mutate(config)
        self._load_genome()

    def to_genome(self) -> Genome:
        """Convert network back to genome representation.

        The genome receives deep copies of this network's ``NodeGene`` and
        ``ConnectionGene`` objects.
        """
        genome = Genome(self.input_size, self.output_size)
        genome.node_genes = copy.deepcopy(self.node_genes)
        genome.connection_genes = copy.deepcopy(self.connection_genes)
        return genome
|
|
|
class BaseNetwork:
    """Base Network class for NEAT: a single dense layer, ``tanh(x @ W.T + b)``.

    Args:
        n_inputs: Number of input features.
        n_outputs: Number of output units.
        seed: Seed for the JAX PRNG key used to draw the initial weights. The
            default of 0 gives every instance the same starting weights; pass
            distinct seeds to obtain a diverse population.
    """

    def __init__(self, n_inputs: int, n_outputs: int, seed: int = 0):
        self.input_size = n_inputs
        self.output_size = n_outputs
        self.fitness = float('-inf')

        key = jax.random.PRNGKey(seed)
        # Weights drawn from N(0, 1) scaled by 0.5; shape (n_outputs, n_inputs).
        self.weights = jax.random.normal(key, (n_outputs, n_inputs)) * 0.5
        self.bias = jnp.ones(n_outputs) * 0.1

    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass through the network.

        ``x`` may be a single sample ``(n_inputs,)`` or a batch
        ``(..., n_inputs)``; the bias broadcasts over any leading batch axes.
        """
        return jnp.tanh(jnp.dot(x, self.weights.T) + self.bias)

    def get_params(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Get network parameters as ``(weights, bias)``."""
        return self.weights, self.bias

    def set_params(self, params: Tuple[jnp.ndarray, jnp.ndarray]):
        """Set network parameters from a ``(weights, bias)`` pair."""
        self.weights, self.bias = params

    def get_weights_numpy(self) -> np.ndarray:
        """Get weights as numpy array for visualization."""
        return np.array(self.weights)
|
|
|
class NodeGene:
    """Node gene containing node information."""

    def __init__(self, node_id: int, node_type: str, activation: str = 'tanh'):
        """Initialize node gene.

        Args:
            node_id: Node ID. For 'hidden' and 'output' nodes it also seeds the
                PRNG key that draws the initial bias.
            node_type: Type of node (this file uses 'input', 'bias', 'hidden'
                and 'output').
            activation: Activation function name ('tanh', 'sigmoid', 'relu' or
                'linear').
        """
        self.id = node_id
        self.type = node_type
        self.activation = activation

        if node_type in ('hidden', 'output'):
            # Deterministic per id: the same node id always gets the same bias.
            self.bias = jax.random.normal(jax.random.PRNGKey(node_id), ()) * 0.5
        else:
            self.bias = 0.0

    @property
    def node_type(self) -> str:
        """Alias of ``type``; ``Network`` reads ``node_type`` from genome node genes."""
        return self.type

    @node_type.setter
    def node_type(self, value: str) -> None:
        self.type = value
|
|
|
class ConnectionGene:
    """Gene representing a connection between nodes.

    Args:
        source: Id of the node the connection reads from.
        target: Id of the node the connection feeds into.
        weight: Connection weight. When ``None``, a weight is drawn uniformly
            from [-2.0, 2.0) with a PRNG key seeded from
            ``hash((source, target)) % 2**32``, so the same node pair gets the
            same default weight.
        enabled: Whether the connection is active.
    """

    def __init__(self, source: int, target: int, weight: Optional[float] = None, enabled: bool = True):
        self.source = source
        self.target = target

        if weight is None:
            key = jax.random.PRNGKey(hash((source, target)) % 2**32)
            weight = jax.random.uniform(key, (), minval=-2.0, maxval=2.0)
        self.weight = weight

        self.enabled = enabled
        # Innovation number; not assigned by this constructor.
        self.innovation = None
|
|