File size: 2,264 Bytes
ca85408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b471ab8
 
ca85408
 
 
934779e
ca85408
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from typing import Dict, Tuple, Optional, Callable, Union
import gymnasium as gym
from kan import KAN
import numpy as np


def extract_dim(space: gym.Space) -> Tuple[int, bool]:
    """Return ``(dim, is_discrete)`` describing a Gymnasium space.

    Args:
        space: A ``Box`` space with a 1-D shape, or a ``Discrete`` space.

    Returns:
        ``(shape[0], False)`` for a 1-D ``Box`` and ``(n, True)`` for a
        ``Discrete`` space. ``dim`` is always a built-in ``int``.

    Raises:
        NotImplementedError: For any other space, including ``Box`` spaces
            whose shape is not 1-D.
    """
    if isinstance(space, gym.spaces.Box) and len(space.shape) == 1:
        return int(space.shape[0]), False
    if isinstance(space, gym.spaces.Discrete):
        # Discrete.n may be a NumPy integer (gymnasium stores it as np.int64);
        # normalise so callers such as a KAN width list get a plain int.
        return int(space.n), True
    raise NotImplementedError(f"There is no support for space {space}.")


class InterpretablePolicyExtractor:
    """Fit a KAN (Kolmogorov-Arnold Network) policy for a Gymnasium environment.

    The network maps an observation vector to action values. For ``Box``
    action spaces the outputs are trained with MSE; for ``Discrete`` action
    spaces they are treated as per-action scores and trained with
    cross-entropy.
    """

    # Candidate symbolic primitives; not referenced inside this class.
    # Presumably intended for KAN's symbolic-fitting step -- TODO confirm.
    lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']

    def __init__(self, env_name: str, hidden_widths: Optional[Tuple[int, ...]] = None):
        """Create the environment and a KAN sized from its spaces.

        Args:
            env_name: Gymnasium environment id passed to ``gym.make``.
            hidden_widths: Widths of the hidden KAN layers. ``None`` (or an
                empty sequence) connects observations directly to actions.
        """
        self.env = gym.make(env_name)
        hidden = list(hidden_widths) if hidden_widths is not None else []
        observation_dim, self._observation_is_discrete = extract_dim(self.env.observation_space)
        action_dim, self._action_is_discrete = extract_dim(self.env.action_space)
        self.policy = KAN(width=[observation_dim, *hidden, action_dim])
        # Discrete actions are a classification problem over action indices.
        self.loss_fn = torch.nn.CrossEntropyLoss() if self._action_is_discrete else torch.nn.MSELoss()

    def _prepare_dataset(self, dataset: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return a shallow copy of *dataset* with the dtypes/shapes training expects.

        - ``*_input`` tensors are cast to float32.
        - Continuous actions: 1-D ``*_label`` tensors become ``(N, 1)`` columns,
          and all labels are cast to float32 to match the network output.
        - Discrete actions: 1-D ``*_label`` tensors are cast to int64 class
          indices, as ``CrossEntropyLoss`` requires; other shapes are kept.

        The caller's dict is never modified.
        """
        prepared = dict(dataset)
        for split in ("train", "test"):
            prepared[f"{split}_input"] = prepared[f"{split}_input"].float()
            label = prepared[f"{split}_label"]
            if self._action_is_discrete:
                if label.ndim == 1:
                    label = label.long()
            else:
                # Each split is checked on its own shape (the original code
                # checked train_label's rank when reshaping test_label).
                if label.ndim == 1:
                    label = label[:, None]
                label = label.float()
            prepared[f"{split}_label"] = label
        return prepared

    def train_from_dataset(self, dataset: Union[Dict[str, torch.Tensor], str], steps: int = 20):
        """Fit the KAN policy to a supervised observation -> action dataset.

        Args:
            dataset: A dict with ``train_input``, ``train_label``,
                ``test_input`` and ``test_label`` tensors, or a path to such a
                dict saved with ``torch.save``.
            steps: Number of LBFGS optimisation steps.

        Returns:
            Whatever ``self.policy.train`` returns (presumably a dict of loss
            histories -- TODO confirm against the installed pykan version).
        """
        if isinstance(dataset, str):
            # NOTE(security): torch.load unpickles the file; only load trusted paths.
            dataset = torch.load(dataset)
        prepared = self._prepare_dataset(dataset)
        return self.policy.train(prepared, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)

    def forward(self, observation):
        """Return the policy's action for a single observation.

        Args:
            observation: One observation as an array-like of numbers (assumed
                1-D with ``observation_dim`` entries -- TODO confirm for
                non-Box observation spaces).

        Returns:
            For discrete actions, the index of the highest-scoring action as a
            Python ``int``; otherwise a 1-D NumPy array of action values.
        """
        # as_tensor accepts lists/tuples as well as ndarrays of any dtype.
        obs = torch.as_tensor(observation, dtype=torch.float32)
        # Inference only: avoid building an autograd graph on every call.
        with torch.no_grad():
            action = self.policy(obs.unsqueeze(0))
        if self._action_is_discrete:
            return action.argmax(dim=-1).squeeze().item()
        return action.squeeze(0).numpy()

    def close(self) -> None:
        """Close the Gymnasium environment created in ``__init__``."""
        self.env.close()

    def train_from_policy(self, policy: Callable[[np.ndarray], Union[np.ndarray, int, float]], steps: int):
        """Train the KAN from a callable policy. Not implemented yet."""
        raise NotImplementedError()  # TODO