Spaces:
Sleeping
Sleeping
File size: 2,359 Bytes
eaf2e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from src.drl.ac_models import *
class ActCrtAgent:
    """Base actor-critic agent that keeps an actor and a critic on one device.

    Subclasses must override :meth:`update` with their training step.

    Args:
        actor: policy object. It must provide ``to(device)`` and
            ``forward(obs_tensor, grad=..., **kwargs)`` returning a 2-tuple whose
            first element is the action tensor.
        critic: value object. It must provide ``to(device)``.
        device: torch device (or device string) that the networks are moved to
            and that observations are placed on.
    """

    def __init__(self, actor, critic, device='cpu'):
        self.actor = actor
        self.critic = critic
        self.device = device
        self.to(device)

    def to(self, device):
        """Move actor and critic to ``device`` and remember it for later inputs."""
        self.actor.to(device)
        self.critic.to(device)
        self.device = device

    @abstractmethod
    def update(self, obs, acts, rews, ops):
        """Run one training step on a batch of transitions.

        ``ops`` is presumably the batch of next observations (TODO confirm
        against the replay buffer).

        Raises:
            NotImplementedError: always, in the base class. ``@abstractmethod``
                is not enforced here because the class does not use ``ABCMeta``.
                Raising makes a missing override fail loudly instead of
                silently skipping training.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement update()'
        )

    def make_decision(self, obs, **kwargs):
        """Return the actor's action for ``obs`` as a NumPy array.

        ``obs`` is converted to a float tensor on ``self.device``.
        ``torch.as_tensor`` avoids an extra copy, and the copy-construct
        warning, when ``obs`` is already a tensor. Extra keyword arguments are
        forwarded to ``actor.forward``. ``grad=False`` presumably disables
        gradient tracking inside the actor (confirm in ac_models).

        Returns:
            The first element of the actor's output, with all size-1
            dimensions squeezed out, moved to the CPU and converted to NumPy.
        """
        a, _ = self.actor.forward(
            torch.as_tensor(obs, dtype=torch.float, device=self.device),
            grad=False, **kwargs
        )
        return a.squeeze().cpu().numpy()
class SAC(ActCrtAgent):
    """Soft Actor-Critic agent built from a SoftActor and a clipped double-Q critic.

    Each ``update`` call performs, in order:
      1. an actor step: reset grads, policy loss against the critic,
         alpha loss, optimizer step;
      2. a critic step: reset grads, MSE loss computed with the actor,
         optimizer step;
      3. ``critic.update_tarnet()`` (presumably a target-network sync;
         confirm in ac_models).
    """

    def __init__(self, actor: SoftActor, critic: SoftDoubleClipCriticQ, device='cpu'):
        super().__init__(actor, critic, device)

    def update(self, obs, acts, rews, ops):
        pi, q = self.actor, self.critic

        # Actor phase: the critic is passed in to score the policy.
        pi.zero_grads()
        pi.backward_policy(q, obs)
        pi.backward_alpha(obs)
        pi.grad_step()

        # Critic phase: uses the actor that was just stepped above.
        q.zero_grads()
        q.backward_mse(pi, obs, acts, rews, ops)
        q.grad_step()
        q.update_tarnet()
class MESAC(ActCrtAgent):
    """SAC variant that adds a second critic, ``criticU``, for the actor's ME regularizer.

    The regularizer and ``criticU`` are only exercised while
    ``actor.me_reg.lbd > 0``. ``lbd`` is presumably the regularization weight
    (confirm in ac_models).
    """

    def __init__(
        self,
        actor: MERegMixSoftActor,
        critic: SoftDoubleClipCriticQ,
        criticU: MERegDoubleClipCriticW,
        device='cpu',
    ):
        # criticU must be set before the base initialiser runs. That
        # initialiser calls self.to(device), which dispatches to MESAC.to below
        # and therefore already moves all three networks. No second
        # self.to(device) call is needed.
        self.criticU = criticU
        super().__init__(actor, critic, device)

    def to(self, device):
        """Move actor, critic and criticU to ``device`` and remember it."""
        super().to(device)
        self.criticU.to(device)

    def update(self, obs, acts, rews, ops):
        """Run one training step: actor, then critic, then (optionally) criticU."""
        # Actor step, with the ME-regularizer gradient only when it is enabled.
        self.actor.zero_grads()
        self.actor.backward_policy(self.critic, obs)
        if self.actor.me_reg.lbd > 0.:
            self.actor.backward_me_reg(self.criticU, obs)
        self.actor.backward_alpha(obs)
        self.actor.grad_step()

        # Critic step, followed by its target-network refresh.
        self.critic.zero_grads()
        self.critic.backward_mse(self.actor, obs, acts, rews, ops)
        self.critic.grad_step()
        self.critic.update_tarnet()

        # criticU is trained only while the regularizer is active. lbd is
        # re-read here, after the actor step, rather than cached.
        if self.actor.me_reg.lbd > 0.:
            self.criticU.zero_grads()
            self.criticU.backward_mse(self.actor, obs, acts, rews, ops)
            self.criticU.grad_step()
            self.criticU.update_tarnet()
if __name__ == "__main__":
    # This module only defines agent classes; running it directly does nothing.
    pass
|