# Hugging Face Spaces page status (scrape residue): Running on CPU Upgrade
import torch | |
from typing import Dict, Tuple, Optional, Callable, Union | |
import gymnasium as gym | |
from kan import KAN | |
import numpy as np | |
def extract_dim(space: gym.Space):
    """Return ``(size, is_discrete)`` for a supported gymnasium space.

    A one-dimensional ``Box`` yields ``(space.shape[0], False)``; a
    ``Discrete`` space yields ``(space.n, True)``.

    Raises:
        NotImplementedError: for any other space, including ``Box`` spaces
            whose shape is not one-dimensional.
    """
    flat_box = isinstance(space, gym.spaces.Box) and len(space.shape) == 1
    if flat_box:
        return space.shape[0], False

    if isinstance(space, gym.spaces.Discrete):
        return space.n, True

    raise NotImplementedError(f"There is no support for space {space}.")
class InterpretablePolicyExtractor:
    """Wrap a KAN whose input/output widths are taken from a gymnasium env.

    The constructor builds the environment named ``env_name``, reads the
    sizes of its observation and action spaces through ``extract_dim`` and
    creates ``KAN(width=[observation_dim, *hidden_widths, action_dim])``.
    The loss is ``CrossEntropyLoss`` for discrete action spaces and
    ``MSELoss`` otherwise.
    """

    # NOTE(review): looks like a list of symbolic function names (presumably
    # for KAN symbolic fitting); it is not referenced by the visible code.
    lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']

    def __init__(self, env_name: str, hidden_widths: Optional[Tuple[int]] = None):
        """Create the environment and a KAN policy sized to its spaces.

        Args:
            env_name: identifier passed to ``gym.make``.
            hidden_widths: widths of the hidden KAN layers; ``None`` means
                no hidden layers.

        Raises:
            NotImplementedError: propagated from ``extract_dim`` when a
                space is not supported.
        """
        self.env = gym.make(env_name)
        if hidden_widths is None:
            hidden_widths = []
        observation_dim, self._observation_is_discrete = extract_dim(self.env.observation_space)
        action_dim, self._action_is_discrete = extract_dim(self.env.action_space)
        self.policy = KAN(width=[observation_dim, *hidden_widths, action_dim])
        if self._action_is_discrete:
            self.loss_fn = torch.nn.CrossEntropyLoss()
        else:
            self.loss_fn = torch.nn.MSELoss()

    def train_from_dataset(self, dataset: Union[Dict[str, torch.Tensor], str], steps: int = 20):
        """Train the policy on a dataset dict (or a path to a saved one).

        The dict is read through the keys ``train_input``, ``test_input``,
        ``train_label`` and ``test_label``. It is modified in place: inputs
        are cast to float and, for non-discrete actions, 1-D label tensors
        gain a trailing axis.

        Args:
            dataset: the dataset dict, or a path handed to ``torch.load``.
            steps: number of optimisation steps forwarded to the policy.

        Returns:
            Whatever ``self.policy.train`` returns.
        """
        if isinstance(dataset, str):
            # SECURITY NOTE: torch.load unpickles the file; only load
            # datasets from trusted sources.
            dataset = torch.load(dataset)

        if not self._action_is_discrete:
            # Each split is checked on its own tensor. Previously the
            # test-split guard re-tested "train_label" (already 2-D by then),
            # so "test_label" was never reshaped.
            # Presumably the trailing axis matches the policy's
            # (N, action_dim) output for the MSE loss -- TODO confirm.
            for label_key in ("train_label", "test_label"):
                if dataset[label_key].ndim == 1:
                    dataset[label_key] = dataset[label_key][:, None]

        dataset["train_input"] = dataset["train_input"].float()
        dataset["test_input"] = dataset["test_input"].float()
        return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)

    def forward(self, observation):
        """Compute the policy's action for a single observation.

        Args:
            observation: a NumPy array (converted with ``torch.from_numpy``).

        Returns:
            For discrete action spaces, the index of the largest policy
            output as a Python scalar; otherwise the policy output as a
            NumPy array with the batch axis removed.
        """
        observation = torch.from_numpy(observation).float()
        # The policy is called on a batch of one.
        action = self.policy(observation.unsqueeze(0))
        if self._action_is_discrete:
            return action.argmax(axis=-1).squeeze().item()
        return action.squeeze(0).detach().numpy()

    def train_from_policy(self, policy: Callable[[np.ndarray], Union[np.ndarray, int, float]], steps: int):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError()  # TODO