File size: 2,359 Bytes
eaf2e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from src.drl.ac_models import *


class ActCrtAgent:
    """Base agent that bundles an actor and a critic on a common device.

    Subclasses are expected to provide ``update``.

    NOTE(review): ``update`` is decorated with ``@abstractmethod`` but this
    class does not derive from ``abc.ABC``, so the decorator is not enforced
    at instantiation time; the base ``update`` simply does nothing.
    """

    def __init__(self, actor, critic, device='cpu'):
        """Store the two networks and move them to ``device``.

        Args:
            actor: object exposing ``to(device)`` and ``forward(obs, grad=..., **kwargs)``.
            critic: object exposing ``to(device)``.
            device: device specifier handed to the networks' ``to`` methods.
        """
        self.actor = actor
        self.critic = critic
        self.device = device
        # Dispatches to a subclass override of ``to`` when there is one.
        self.to(device)

    def to(self, device):
        """Move the actor and the critic to ``device`` and remember it."""
        self.actor.to(device)
        self.critic.to(device)
        self.device = device

    @abstractmethod
    def update(self, obs, acts, rews, ops):
        """Run one learning step; the base implementation is a no-op.

        NOTE(review): judging by the names, the arguments are batches of
        observations, actions, rewards and (presumably) next observations --
        confirm against the callers.
        """
        pass

    def make_decision(self, obs, **kwargs):
        """Query the actor for an action without tracking gradients.

        Args:
            obs: observation data convertible by ``torch.tensor``.
            **kwargs: forwarded unchanged to ``actor.forward``.

        Returns:
            The first element returned by ``actor.forward``, squeezed of
            singleton dimensions and converted to a NumPy array on the CPU.
        """
        obs_tensor = torch.tensor(obs, dtype=torch.float, device=self.device)
        action, _ = self.actor.forward(obs_tensor, grad=False, **kwargs)
        # ``Tensor.numpy()`` refuses tensors that require grad; detaching is a
        # no-op when the actor already honoured ``grad=False``.
        return action.detach().squeeze().cpu().numpy()


class SAC(ActCrtAgent):
    """Actor-critic agent built from a ``SoftActor`` and a ``SoftDoubleClipCriticQ``."""

    def __init__(self, actor: SoftActor, critic: SoftDoubleClipCriticQ, device='cpu'):
        super().__init__(actor, critic, device)

    def update(self, obs, acts, rews, ops):
        """Run one update: the actor is stepped first, then the critic.

        The actor phase uses only ``obs``; the critic phase receives all four
        arguments and finishes by calling ``update_tarnet`` on the critic.
        """
        policy = self.actor
        q_net = self.critic

        # Actor phase: clear gradients, run both backward passes, then step.
        policy.zero_grads()
        policy.backward_policy(q_net, obs)
        policy.backward_alpha(obs)
        policy.grad_step()

        # Critic phase: runs after the actor step, against the updated actor.
        q_net.zero_grads()
        q_net.backward_mse(policy, obs, acts, rews, ops)
        q_net.grad_step()
        q_net.update_tarnet()


class MESAC(ActCrtAgent):
    """Actor-critic agent with an additional critic, ``criticU``.

    ``criticU`` is used by the actor's ``backward_me_reg`` pass and is itself
    trained only while ``actor.me_reg.lbd`` is positive.
    """

    def __init__(self, actor: MERegMixSoftActor, critic: SoftDoubleClipCriticQ, criticU: MERegDoubleClipCriticW, device='cpu'):
        # ``criticU`` must be bound before the base initializer runs: that
        # initializer calls ``self.to(device)``, which resolves to the override
        # below and moves ``criticU`` as well. A second explicit ``to`` call is
        # therefore unnecessary.
        self.criticU = criticU
        super().__init__(actor, critic, device)

    def to(self, device):
        """Move the actor, both critics to ``device`` and remember it."""
        super().to(device)
        self.criticU.to(device)

    def update(self, obs, acts, rews, ops):
        """Run one update: actor first, then ``critic``, then ``criticU``.

        The ``backward_me_reg`` actor pass and the whole ``criticU`` update are
        skipped when ``actor.me_reg.lbd`` is not positive.
        """
        # Actor step.
        self.actor.zero_grads()
        self.actor.backward_policy(self.critic, obs)
        if self.actor.me_reg.lbd > 0.:
            self.actor.backward_me_reg(self.criticU, obs)
        self.actor.backward_alpha(obs)
        self.actor.grad_step()

        # Main critic step.
        self.critic.zero_grads()
        self.critic.backward_mse(self.actor, obs, acts, rews, ops)
        self.critic.grad_step()
        self.critic.update_tarnet()

        # ``lbd`` is deliberately re-read here rather than cached, in case the
        # actor step above changed it.
        if self.actor.me_reg.lbd > 0.:
            self.criticU.zero_grads()
            self.criticU.backward_mse(self.actor, obs, acts, rews, ops)
            self.criticU.grad_step()
            self.criticU.update_tarnet()



# Script entry point: intentionally a no-op placeholder; running this module
# directly only defines the agent classes above.
if __name__ == '__main__':

    pass