NCERL-Diverse-PCG / src /drl /ac_agents.py
baiyanlali-zhao's picture
init
eaf2e33
from src.drl.ac_models import *
class ActCrtAgent:
    """Base agent that pairs an actor with a critic on a common device.

    Subclasses are expected to supply ``update``. Note that this class does
    not derive from ``abc.ABC``, so the ``@abstractmethod`` marker documents
    intent only and does not prevent instantiation.
    """

    def __init__(self, actor, critic, device='cpu'):
        """Keep references to both models and move them to ``device``.

        Args:
            actor: Object providing ``to(device)`` and ``forward(...)``.
            critic: Object providing ``to(device)``.
            device: Device handed to ``to``; defaults to ``'cpu'``.
        """
        self.actor, self.critic = actor, critic
        self.device = device
        # Dispatches to a subclass override of ``to`` when one exists.
        self.to(device)

    def to(self, device):
        """Move the actor and the critic to ``device`` and record it."""
        for model in (self.actor, self.critic):
            model.to(device)
        self.device = device

    @abstractmethod
    def update(self, obs, acts, rews, ops):
        """Perform one learning step; the base implementation does nothing.

        Args:
            obs, acts, rews, ops: Batch components forwarded by subclasses to
                their models (presumably observations, actions, rewards and
                next observations -- confirm against the caller).
        """

    def make_decision(self, obs, **kwargs):
        """Query the actor for an action without gradients.

        Args:
            obs: Observation data convertible by ``torch.tensor``.
            **kwargs: Extra keyword arguments passed through to
                ``actor.forward``.

        Returns:
            The first element returned by ``actor.forward``, squeezed, moved
            to the CPU and converted to a NumPy array.
        """
        obs_tensor = torch.tensor(obs, dtype=torch.float, device=self.device)
        action, _ = self.actor.forward(obs_tensor, grad=False, **kwargs)
        # squeeze() drops every size-1 dimension, including a batch of one.
        return action.squeeze().cpu().numpy()
class SAC(ActCrtAgent):
    """Actor-critic agent built from a ``SoftActor`` and a ``SoftDoubleClipCriticQ``."""

    def __init__(self, actor: SoftActor, critic: SoftDoubleClipCriticQ, device='cpu'):
        """Forward the actor, the critic and the device to the base class."""
        super().__init__(actor, critic, device)

    def update(self, obs, acts, rews, ops):
        """Run one update of the actor followed by one update of the critic.

        Args:
            obs: Passed to the actor's policy/alpha backward passes and to the
                critic's ``backward_mse``.
            acts, rews, ops: Passed only to the critic's ``backward_mse``
                (presumably actions, rewards and next observations --
                confirm against the caller).
        """
        policy, q_net = self.actor, self.critic

        # Actor phase: both backward passes run before the single grad_step.
        policy.zero_grads()
        policy.backward_policy(q_net, obs)
        policy.backward_alpha(obs)
        policy.grad_step()

        # Critic phase: runs after the actor step, so backward_mse receives
        # the already-stepped actor; the target network is refreshed last.
        q_net.zero_grads()
        q_net.backward_mse(policy, obs, acts, rews, ops)
        q_net.grad_step()
        q_net.update_tarnet()
class MESAC(ActCrtAgent):
    """Actor-critic agent with an additional critic (``criticU``).

    The extra critic is only used and trained while ``actor.me_reg.lbd`` is
    positive (presumably the weight of the ME regulariser -- confirm in
    ``ac_models``).
    """

    def __init__(
        self,
        actor: MERegMixSoftActor,
        critic: SoftDoubleClipCriticQ,
        criticU: MERegDoubleClipCriticW,
        device='cpu',
    ):
        """Store the three models and move them all to ``device``.

        Args:
            actor: Policy model; must expose ``me_reg.lbd``.
            critic: Critic passed to ``actor.backward_policy``.
            criticU: Critic passed to ``actor.backward_me_reg``.
            device: Device handed to ``to``; defaults to ``'cpu'``.
        """
        # criticU has to be assigned first: the base constructor calls the
        # overridden ``to``, which already moves criticU as well. No second
        # ``self.to(device)`` is needed afterwards.
        self.criticU = criticU
        super().__init__(actor, critic, device)

    def to(self, device):
        """Move the actor and both critics to ``device`` and record it."""
        for model in (self.actor, self.critic, self.criticU):
            model.to(device)
        self.device = device

    def update(self, obs, acts, rews, ops):
        """Run one update of the actor, the critic and, if active, ``criticU``.

        Args:
            obs: Passed to every backward pass of the actor and the critics.
            acts, rews, ops: Passed only to the critics' ``backward_mse``
                (presumably actions, rewards and next observations --
                confirm against the caller).
        """
        # Actor phase: all backward passes run before the single grad_step.
        self.actor.zero_grads()
        self.actor.backward_policy(self.critic, obs)
        if self.actor.me_reg.lbd > 0.:
            self.actor.backward_me_reg(self.criticU, obs)
        self.actor.backward_alpha(obs)
        self.actor.grad_step()

        # Critic phase: runs after the actor step, so backward_mse receives
        # the already-stepped actor.
        self.critic.zero_grads()
        self.critic.backward_mse(self.actor, obs, acts, rews, ops)
        self.critic.grad_step()
        self.critic.update_tarnet()

        # lbd is read again here rather than cached, in case the actor step
        # changed it.
        if self.actor.me_reg.lbd > 0.:
            self.criticU.zero_grads()
            self.criticU.backward_mse(self.actor, obs, acts, rews, ops)
            self.criticU.grad_step()
            self.criticU.update_tarnet()
if __name__ == "__main__":
    # Nothing is executed when this module is run as a script.
    pass