import torch
from src.drl.ac_models import SoftActor
from src.drl.nets import esmb_sample


class PMOESoftActor(SoftActor):
    """Soft actor whose policy head is a mixture of Gaussian experts
    (PMOE-style): the network produces per-expert means, stds, and
    mixture weights (betas)."""

    def __init__(self, net_constructor, tar_ent=None):
        super().__init__(net_constructor, tar_ent)
    def forward(self, obs, grad=True, mono=True):
        # Run the policy network; `mono` is forwarded to the network as-is.
        if grad:
            return self.net(obs, mono)
        with torch.no_grad():
            return self.net(obs, mono)
    def backward_policy(self, critic, obs):
        # Mixture parameters for each of the k experts, then one action
        # sample (and its log-prob) per expert.
        muss, stdss, betas = self.net.get_intermediate(obs)
        actss, logpss, _ = esmb_sample(muss, stdss, betas, False)
        # Evaluate the critic on every expert's action for the same state.
        obss = torch.unsqueeze(obs, dim=1).expand(-1, actss.shape[1], -1)
        qvaluess = critic.forward(obss, actss)
        # Standard soft policy loss, summed over experts:
        # L_pri = E[ alpha * log pi - Q ].
        l_pri = (torch.sum(self.alpha_coe(logpss, grad=False) - qvaluess, dim=-1)).mean()
        # Regularizer: push the mixture weights toward a one-hot on the
        # expert whose action scores the highest Q-value.
        t = qvaluess - torch.max(qvaluess, dim=-1, keepdim=True).values
        v = torch.where(t == 0., 1., 0.) - betas
        l_frep = (v * v).sum(-1).mean()
        l = l_frep + l_pri
        l.backward()
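    # Illustration only (not part of the original module): a tiny worked
    # example of the l_frep target above, with made-up numbers (batch=1,
    # k=3 experts). The one-hot marks the expert with the highest Q-value,
    # and l_frep pulls the mixture weights betas toward it:
    #
    #   q      = torch.tensor([[0.2, 1.5, 0.9]])
    #   onehot = torch.where(q - q.max(-1, keepdim=True).values == 0., 1., 0.)
    #   betas  = torch.tensor([[0.5, 0.3, 0.2]])
    #   l_frep = ((onehot - betas) ** 2).sum(-1).mean()  # pulls betas to [0, 1, 0]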
    def backward_alpha(self, obs):
        # Temperature (alpha) loss, as in SAC:
        # L_alpha = -E[ alpha * (log pi + target_entropy) ].
        _, logps = self.forward(obs, grad=False)
        loss_alpha = -(self.alpha_coe(logps + self.tar_ent)).mean()
        loss_alpha.backward()
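
# ---------------------------------------------------------------------------
# Illustration only (not part of the original module): a minimal standalone
# sketch of what `esmb_sample` presumably does, assuming `muss`/`stdss` have
# shape [batch, k, act_dim] and `betas` has shape [batch, k]. The tanh
# squashing and the third return value are assumptions, and the original
# takes an extra boolean flag not modeled here.
def _esmb_sample_sketch(muss, stdss, betas):
    dist = torch.distributions.Normal(muss, stdss)
    raw = dist.rsample()                                # [batch, k, act_dim]
    acts = torch.tanh(raw)                              # squash to (-1, 1)
    # Log-prob with the tanh change-of-variables correction, summed over
    # action dimensions -> one log-prob per expert: [batch, k].
    logps = (dist.log_prob(raw)
             - torch.log(1. - acts.pow(2) + 1e-6)).sum(-1)
    return acts, logps, betas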