Spaces:
Runtime error
Runtime error
from typing import Any, Sequence | |
from chex import PRNGKey | |
import distrax | |
from flax import struct | |
import jax | |
import jax.numpy as jnp | |
class HybridAction: | |
discrete: int | |
continuous: jnp.ndarray | |
class HybridActionDistribution(distrax.Distribution): | |
def __init__(self, discrete_logits, continuous_mu, continuous_sigma) -> None: | |
self.discrete = distrax.Categorical(logits=discrete_logits) | |
self.continuous = distrax.MultivariateNormalDiag(continuous_mu, continuous_sigma) | |
def _sample_n(self, rng: PRNGKey, n: int) -> Any: | |
rng, _rng, _rng2 = jax.random.split(rng, 3) | |
a = self.discrete._sample_n(_rng, n) | |
b = self.continuous._sample_n(_rng2, n) | |
return HybridAction(a, b) | |
def log_prob(self, value: Any): | |
a = self.discrete.log_prob(value.discrete) | |
b = self.continuous.log_prob(value.continuous) | |
return a + b # log probs, we add. | |
def entropy(self): | |
return self.discrete.entropy() + self.continuous.entropy() | |
def event_shape(self) -> Sequence[int]: | |
return () | |
class MultiDiscreteActionDistribution(distrax.Distribution): | |
def __init__(self, flat_logits, number_of_dims_per_distribution) -> None: | |
self.distributions = [] | |
total_dims = 0 | |
for dims in number_of_dims_per_distribution: | |
self.distributions.append(distrax.Categorical(logits=flat_logits[..., total_dims : total_dims + dims])) | |
total_dims += dims | |
def _sample_n(self, key: PRNGKey, n: int) -> Any: | |
rngs = jax.random.split(key, len(self.distributions)) | |
samples = [jnp.expand_dims(d._sample_n(rng, n), axis=-1) for rng, d in zip(rngs, self.distributions)] | |
return jnp.concatenate(samples, axis=-1) | |
def log_prob(self, value: Any): | |
return sum(d.log_prob(value[..., i]) for i, d in enumerate(self.distributions)) | |
def entropy(self): | |
return sum(d.entropy() for d in self.distributions) | |
def event_shape(self) -> Sequence[int]: | |
return () | |