Spaces:
Sleeping
Sleeping
""" | |
Torch argmax policy | |
""" | |
import numpy as np | |
from torch import nn | |
import rlkit.torch.pytorch_util as ptu | |
from rlkit.policies.base import Policy | |
class ArgmaxDiscretePolicy(nn.Module, Policy):
    """Greedy policy that picks the action with the highest Q-value.

    Wraps a Q-function module ``qf``. Given a single observation, it returns
    the index of the largest Q-value that ``qf`` produces for it.
    """

    def __init__(self, qf):
        """
        :param qf: Q-function module, called on a batch of observations.
            Presumably returns a (batch, num_actions) tensor -- TODO confirm
            against the Q-function implementations used with this policy.
        """
        super().__init__()
        self.qf = qf

    def get_action(self, obs):
        """Return the greedy action for a single (unbatched) observation.

        :param obs: one observation as a numpy-compatible array. A leading
            batch axis of size 1 is added before it is passed to ``qf``.
        :return: tuple ``(action, agent_info)``. ``action`` is the result of
            numpy ``argmax`` over the Q-values (an index into the flattened
            array). ``agent_info`` is always an empty dict.
        """
        # Local import: the module only brings ``nn`` from torch into scope.
        import torch

        # Add a batch axis of size 1 so qf receives a batch of one.
        obs = np.expand_dims(obs, axis=0)
        obs = ptu.from_numpy(obs).float()
        # Action selection is pure inference; disabling autograd avoids
        # building a computation graph on every call. The result is the same
        # because get_numpy below works on the values only.
        with torch.no_grad():
            q_values = self.qf(obs).squeeze(0)
        q_values_np = ptu.get_numpy(q_values)
        return q_values_np.argmax(), {}