import torch

from src.drl.ac_models import SoftActor
from src.drl.nets import esmb_sample


class PMOESoftActor(SoftActor):
    """SoftActor whose policy is a gated mixture of Gaussian experts (PMOE)."""

    def __init__(self, net_constructor, tar_ent=None):
        super().__init__(net_constructor, tar_ent)

    def forward(self, obs, grad=True, mono=True):
        # Evaluate the policy net, with or without gradient tracking.
        if grad:
            return self.net(obs, mono)
        with torch.no_grad():
            return self.net(obs, mono)
    def backward_policy(self, critic, obs):
        # Per-expert means, stds, and gating weights from the mixture head.
        muss, stdss, betas = self.net.get_intermediate(obs)
        # Sample an action (and its log-prob) from every expert.
        actss, logpss, _ = esmb_sample(muss, stdss, betas, False)
        # Broadcast each observation across the expert dimension.
        obss = torch.unsqueeze(obs, dim=1).expand(-1, actss.shape[1], -1)
        qvaluess = critic.forward(obss, actss)
        # SAC-style policy loss, summed over experts.
        l_pri = torch.sum(self.alpha_coe(logpss, grad=False) - qvaluess, dim=-1).mean()
        # Gating loss: regress `betas` onto a one-hot of the highest-Q expert.
        t = qvaluess - torch.max(qvaluess, -1, keepdim=True).values
        v = torch.where(t == 0., 1., 0.) - betas
        l_frep = (v * v).sum(-1).mean()
        loss = l_frep + l_pri
        loss.backward()
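    # In equation form (K experts, a_k sampled from expert k, alpha = entropy
    # temperature), backward_policy above minimizes
    #   L_pri  = E_s[ sum_k ( alpha * log pi_k(a_k | s) - Q(s, a_k) ) ]
    #   L_frep = E_s[ sum_k ( 1[k = argmax_j Q(s, a_j)] - beta_k )^2 ]
    # i.e. every expert follows the SAC policy objective while the gating
    # weights regress onto a one-hot of the best-Q expert.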
    def backward_alpha(self, obs):
        # Temperature loss: push the policy entropy toward the target `tar_ent`.
        _, logps = self.forward(obs, grad=False)
        loss_alpha = -(self.alpha_coe(logps + self.tar_ent)).mean()
        loss_alpha.backward()
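
# --- Hypothetical usage sketch (not part of the original file) ---
# `make_pmoe_net`, `critic`, `obs_batch`, and `action_dim` stand in for the
# repo's actual net constructor, Q-critic, observation batch, and action
# dimensionality; only the call pattern of the methods above is illustrated.
#
#   actor = PMOESoftActor(make_pmoe_net, tar_ent=-action_dim)
#   actor.backward_policy(critic, obs_batch)  # accumulates policy gradients
#   actor.backward_alpha(obs_batch)           # accumulates temperature gradient
#   # ...then step the corresponding optimizers and zero grads as usual.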