kinet-test / kinetix /models /action_spaces.py
tree3po's picture
Upload 46 files
581eeac verified
raw
history blame
2.03 kB
from typing import Any, Sequence
from chex import PRNGKey
import distrax
from flax import struct
import jax
import jax.numpy as jnp
@struct.dataclass
class HybridAction:
    """Container pairing a discrete action with a continuous action array."""

    # Discrete component of the action (annotated as an int).
    discrete: int
    # Continuous component of the action; its shape is not fixed by this class.
    continuous: jnp.ndarray
class HybridActionDistribution(distrax.Distribution):
    """Joint distribution over a categorical part and a diagonal-normal part.

    The two parts are treated as independent: a sample is a ``HybridAction``
    holding one draw from each, and log-probabilities / entropies of the two
    parts are summed.
    """

    def __init__(self, discrete_logits, continuous_mu, continuous_sigma) -> None:
        """Create the two component distributions.

        Args:
            discrete_logits: Logits handed to ``distrax.Categorical``.
            continuous_mu: First positional argument forwarded to
                ``distrax.MultivariateNormalDiag``.
            continuous_sigma: Second positional argument forwarded to
                ``distrax.MultivariateNormalDiag``.
        """
        self.discrete = distrax.Categorical(logits=discrete_logits)
        self.continuous = distrax.MultivariateNormalDiag(continuous_mu, continuous_sigma)

    def _sample_n(self, rng: PRNGKey, n: int) -> Any:
        """Draw ``n`` samples from each component and bundle them in a ``HybridAction``."""
        # The key is split three ways and only the last two sub-keys are used,
        # one per component, so the two draws come from separate key streams.
        _, discrete_key, continuous_key = jax.random.split(rng, 3)
        discrete_draws = self.discrete._sample_n(discrete_key, n)
        continuous_draws = self.continuous._sample_n(continuous_key, n)
        return HybridAction(discrete_draws, continuous_draws)

    def log_prob(self, value: Any):
        """Return the joint log-probability of ``value``.

        ``value`` must expose ``.discrete`` and ``.continuous`` attributes; the
        component log-probabilities are added together.
        """
        return self.discrete.log_prob(value.discrete) + self.continuous.log_prob(value.continuous)

    def entropy(self):
        """Return the sum of the two component entropies."""
        discrete_entropy = self.discrete.entropy()
        continuous_entropy = self.continuous.entropy()
        return discrete_entropy + continuous_entropy

    # NOTE(review): distrax's base class appears to expose `event_shape` as a
    # property, whereas this is a plain method -- confirm what callers expect.
    def event_shape(self) -> Sequence[int]:
        """Return an empty tuple as the event shape."""
        return ()
class MultiDiscreteActionDistribution(distrax.Distribution):
    """Several categorical distributions parameterised by one flat logits array.

    The last axis of the logits is cut into consecutive slices, one per
    categorical; log-probabilities and entropies are summed over the
    categoricals.
    """

    def __init__(self, flat_logits, number_of_dims_per_distribution) -> None:
        """Slice ``flat_logits`` into one ``distrax.Categorical`` per entry.

        Args:
            flat_logits: Array whose last axis holds the logits of every
                categorical laid out back to back.
            number_of_dims_per_distribution: Iterable giving, in order, how many
                entries of the last axis belong to each categorical.
        """
        components = []
        start = 0
        for width in number_of_dims_per_distribution:
            stop = start + width
            components.append(distrax.Categorical(logits=flat_logits[..., start:stop]))
            start = stop
        self.distributions = components

    def _sample_n(self, key: PRNGKey, n: int) -> Any:
        """Draw ``n`` samples from every categorical and join them on a new last axis."""
        # One independent sub-key per categorical.
        sub_keys = jax.random.split(key, len(self.distributions))
        columns = []
        for sub_key, component in zip(sub_keys, self.distributions):
            draws = component._sample_n(sub_key, n)
            # Add a trailing axis so the per-categorical draws can be concatenated along it.
            columns.append(jnp.expand_dims(draws, axis=-1))
        return jnp.concatenate(columns, axis=-1)

    def log_prob(self, value: Any):
        """Return the summed log-probability of ``value``.

        Entry ``i`` along the last axis of ``value`` is scored by the ``i``-th
        categorical.
        """
        total = 0
        for index, component in enumerate(self.distributions):
            total = total + component.log_prob(value[..., index])
        return total

    def entropy(self):
        """Return the sum of the entropies of all categoricals."""
        total = 0
        for component in self.distributions:
            total = total + component.entropy()
        return total

    # NOTE(review): distrax's base class appears to expose `event_shape` as a
    # property, whereas this is a plain method -- confirm what callers expect.
    def event_shape(self) -> Sequence[int]:
        """Return an empty tuple as the event shape."""
        return ()