"""Evaluation utilities: a gradient-free agent wrapper (AgentWrap) and a Tester that
runs a fixed set of test episodes and logs rewards and success rates to a pickle file."""

import os
import pickle

import numpy as np
import torch

import utils

class AgentWrap:
    """Wraps an actor-critic model for action selection with gradients disabled during evaluation."""

    def __init__(self, acmodel, preprocess_obss, device, num_envs=1, argmax=False):
        self.preprocess_obss = preprocess_obss
        self.acmodel = acmodel

        self.device = device
        self.argmax = argmax
        self.num_envs = num_envs

        # Recurrent models keep one memory vector per environment.
        if self.acmodel.recurrent:
            self.memories = torch.zeros(self.num_envs, self.acmodel.memory_size, device=self.device)

    def get_actions(self, obss):
        preprocessed_obss = self.preprocess_obss(obss, device=self.device)

        # Forward pass without tracking gradients; recurrent models also return updated memories.
        with torch.no_grad():
            if self.acmodel.recurrent:
                dist, _, self.memories = self.acmodel(preprocessed_obss, self.memories)
            else:
                dist, _ = self.acmodel(preprocessed_obss)

        # The model yields either a single distribution or a list of distributions
        # (one per action head); take the argmax or sample from each accordingly.
        if isinstance(dist, torch.distributions.Distribution):
            if self.argmax:
                actions = dist.probs.max(1, keepdim=True)[1]
            else:
                actions = dist.sample()
        else:
            if self.argmax:
                actions = torch.stack([d.probs.max(1)[1] for d in dist], dim=1)
            else:
                actions = torch.stack([d.sample() for d in dist], dim=1)
        return self.acmodel.construct_final_action(actions.cpu().numpy())

    def get_action(self, obs):
        return self.get_actions([obs])[0]

    def analyze_feedbacks(self, rewards, dones):
        # Zero out the recurrent memory of every environment whose episode just ended.
        if self.acmodel.recurrent:
            masks = 1 - torch.tensor(dones, dtype=torch.float, device=self.device).unsqueeze(1)
            self.memories *= masks

    def analyze_feedback(self, reward, done):
        return self.analyze_feedbacks([reward], [done])
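
# Usage sketch (illustrative only): AgentWrap is driven one step at a time during
# evaluation, and analyze_feedback should be called after every env.step so that
# recurrent memories are reset when an episode ends. The environment construction
# below is a hypothetical example, not part of this module.
#
#     env = utils.make_env(**env_args)
#     agent = AgentWrap(acmodel, preprocess_obss, device, argmax=True)
#     obs, done = env.reset(), False
#     while not done:
#         action = agent.get_action(obs)
#         obs, reward, done, info = env.step(action)
#         agent.analyze_feedback(reward, done)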


class Tester:
    """Runs a fixed set of evaluation episodes and records rewards and success rates."""

    def __init__(self, env_args, seed, episodes, save_path, acmodel, preprocess_obss, device):
        # One environment per evaluation episode, each seeded deterministically at test time.
        self.envs = [utils.make_env(**env_args) for _ in range(episodes)]
        self.seed = seed
        self.episodes = episodes
        self.ep_counter = 0
        self.savefile = os.path.join(save_path, "testing_{}.pkl".format(self.envs[0].spec.id))
        print("Testing log: ", self.savefile)

        self.stats_dict = {"test_rewards": [], "test_success_rates": [], "test_step_nb": []}
        self.agent = AgentWrap(acmodel, preprocess_obss, device)

    def test_agent(self, num_frames):
        # Switch the model to eval mode for the duration of the evaluation run.
        self.agent.acmodel.eval()

        # Record the number of training frames at which this evaluation happens.
        self.stats_dict['test_step_nb'].append(num_frames)

        rewards = []
        success_rates = []

        for episode in range(self.episodes):
            # Seed each evaluation environment with its own fixed seed so runs are comparable.
            self.envs[episode].seed(episode)
            obs = self.envs[episode].reset()

            done = False
            while not done:
                action = self.agent.get_action(obs)
                obs, reward, done, info = self.envs[episode].step(action)
                self.agent.analyze_feedback(reward, done)

            # Final reward and success flag of the finished episode.
            rewards.append(reward)
            success_rates.append(info['success'])

        mean_rewards = np.array(rewards).mean()
        mean_success_rates = np.array(success_rates).mean()

        self.stats_dict["test_rewards"].append(mean_rewards)
        self.stats_dict["test_success_rates"].append(mean_success_rates)

        # Restore training mode before returning control to the training loop.
        self.agent.acmodel.train()
        return mean_success_rates, mean_rewards

    def load(self):
        # Resume logging from an existing stats file (e.g. when training restarts from a checkpoint).
        if os.path.isfile(self.savefile):
            with open(self.savefile, 'rb') as f:
                stats_dict_loaded = pickle.load(f)

            for k, v in stats_dict_loaded.items():
                self.stats_dict[k] = v
        else:
            raise ValueError(f"Save file {self.savefile} doesn't exist.")

    def dump(self):
        # Persist the accumulated statistics to the pickle log file.
        with open(self.savefile, 'wb') as f:
            pickle.dump(self.stats_dict, f)
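
# Typical lifecycle sketch (assumptions: `acmodel`, `preprocess_obss`, `env_args`,
# `save_dir` and `device` come from the surrounding training code; the names below
# are illustrative, not this module's API):
#
#     tester = Tester(env_args, seed=1, episodes=16, save_path=save_dir,
#                     acmodel=acmodel, preprocess_obss=preprocess_obss, device=device)
#     tester.load()    # only when resuming and the .pkl file already exists
#     ...
#     success_rate, mean_reward = tester.test_agent(num_frames)   # call periodically
#     tester.dump()    # persist stats after each evaluation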