import numpy as np
import torch
from torch import nn
from src.rlkit.policies.base import ExplorationPolicy, Policy
from src.rlkit.torch.core import eval_np
from src.rlkit.torch.distributions import TanhNormal
from src.rlkit.torch.networks import Mlp

LOG_SIG_MAX = 2
LOG_SIG_MIN = -20


class TanhGaussianPolicy(Mlp, ExplorationPolicy):
"""
Usage:
```
policy = TanhGaussianPolicy(...)
action, mean, log_std, _ = policy(obs)
action, mean, log_std, _ = policy(obs, deterministic=True)
action, mean, log_std, log_prob = policy(obs, return_log_prob=True)
```
Here, mean and log_std are the mean and log_std of the Gaussian that is
sampled from.
If deterministic is True, action = tanh(mean).
If return_log_prob is False (default), log_prob = None
This is done because computing the log_prob can be a bit expensive.
"""
def __init__(
self,
hidden_sizes,
obs_dim,
action_dim,
std=None,
init_w=1e-3,
**kwargs
):
super().__init__(
hidden_sizes,
input_size=obs_dim,
output_size=action_dim,
init_w=init_w,
**kwargs
)
self.log_std = None
self.std = std
if std is None:
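            # No fixed std was given: add a separate output head that
            # predicts a state-dependent log_std from the last hidden layer.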
last_hidden_size = obs_dim
if len(hidden_sizes) > 0:
last_hidden_size = hidden_sizes[-1]
self.last_fc_log_std = nn.Linear(last_hidden_size, action_dim)
# self.last_fc_log_std.weight.data.uniform_(-init_w, init_w)
# self.last_fc_log_std.bias.data.uniform_(-init_w, init_w)
else:
self.log_std = np.log(std)
            assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX

    def get_action(self, obs_np, deterministic=False):
actions = self.get_actions(obs_np[None], deterministic=deterministic)
        return actions[0, :], {}

    def get_actions(self, obs_np, deterministic=False):
        return eval_np(self, obs_np, deterministic=deterministic)[0]

    def forward(
self,
obs,
reparameterize=True,
deterministic=False,
return_log_prob=False,
):
"""
:param obs: Observation
:param deterministic: If True, do not sample
:param return_log_prob: If True, return a sample and its log probability
"""
h = obs
for i, fc in enumerate(self.fcs):
h = self.hidden_activation(fc(h))
mean = self.last_fc(h)
if self.std is None:
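            # State-dependent std: predict log_std and clamp it to
            # [LOG_SIG_MIN, LOG_SIG_MAX] for numerical stability.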
log_std = self.last_fc_log_std(h)
log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
std = torch.exp(log_std)
else:
std = self.std
log_std = self.log_std
log_prob = None
entropy = None
mean_action_log_prob = None
pre_tanh_value = None
if deterministic:
action = torch.tanh(mean)
else:
tanh_normal = TanhNormal(mean, std)
if return_log_prob:
if reparameterize is True:
action, pre_tanh_value = tanh_normal.rsample(
return_pretanh_value=True
)
else:
action, pre_tanh_value = tanh_normal.sample(
return_pretanh_value=True
)
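                # TanhNormal.log_prob is expected to apply the tanh
                # change-of-variables correction (log N(u) - log(1 - tanh(u)^2));
                # summing over action dimensions below gives a per-sample value.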
log_prob = tanh_normal.log_prob(
action,
pre_tanh_value=pre_tanh_value
)
log_prob = log_prob.sum(dim=1, keepdim=True)
else:
if reparameterize is True:
action = tanh_normal.rsample()
else:
action = tanh_normal.sample()
return (
action, mean, log_std, log_prob, entropy, std,
mean_action_log_prob, pre_tanh_value,
        )


class MakeDeterministic(nn.Module, Policy):
def __init__(self, stochastic_policy):
super().__init__()
        self.stochastic_policy = stochastic_policy

    def get_action(self, observation):
return self.stochastic_policy.get_action(observation,
deterministic=True)
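

# Minimal usage sketch, assuming the rlkit-style Mlp and eval_np helpers
# imported above behave as in standard rlkit; the obs/action dimensions and
# hidden sizes here are illustrative, not taken from the original module.
if __name__ == "__main__":
    policy = TanhGaussianPolicy(
        hidden_sizes=[256, 256],
        obs_dim=17,
        action_dim=6,
    )
    obs = torch.randn(4, 17)
    action, mean, log_std, log_prob, *_ = policy(obs, return_log_prob=True)
    print(action.shape, log_prob.shape)  # expected: (4, 6) and (4, 1)

    # Wrap the stochastic policy so evaluation always uses tanh(mean).
    eval_policy = MakeDeterministic(policy)
    det_action, _ = eval_policy.get_action(np.random.randn(17))
    print(det_action.shape)  # expected: (6,)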