Spaces:
Sleeping
Sleeping
File size: 517 Bytes
eaf2e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
"""
Torch argmax policy
"""
import numpy as np
import torch
from torch import nn

import rlkit.torch.pytorch_util as ptu
from rlkit.policies.base import Policy
class ArgmaxDiscretePolicy(nn.Module, Policy):
    """Greedy policy that acts by taking the argmax of a Q-function's output.

    Wraps a Q-function module ``qf``. For a single observation it evaluates
    ``qf`` and returns the index of the largest value as the action.
    """

    def __init__(self, qf):
        """
        :param qf: Q-function module, called on a batch of observations.
            Presumably returns a (batch, num_actions) tensor of Q-values;
            TODO confirm against the qf implementations used with this policy.
        """
        super().__init__()
        self.qf = qf

    def get_action(self, obs):
        """Return the greedy action for one (un-batched) observation.

        :param obs: a single observation, array-like. A leading batch axis
            of size 1 is added before it is passed to ``qf``.
        :return: tuple ``(action, agent_info)``. ``action`` is the NumPy
            ``argmax`` of ``qf``'s output with the batch axis removed. Ties
            resolve to the lowest index, and ``argmax`` flattens, so this
            assumes a 1-D vector of per-action values. ``agent_info`` is an
            empty dict.
        """
        # Add a batch dimension so qf sees a batch of size 1.
        obs = np.expand_dims(obs, axis=0)
        obs = ptu.from_numpy(obs).float()
        # Pure inference: don't build an autograd graph on every step. The
        # result is detached and converted to NumPy below anyway.
        with torch.no_grad():
            q_values = self.qf(obs).squeeze(0)
        q_values_np = ptu.get_numpy(q_values)
        return q_values_np.argmax(), {}
|