import os
import pickle

import numpy as np
import torch

import utils


class AgentWrap:
    """Handles action selection without gradient updates for proper testing."""

    def __init__(self, acmodel, preprocess_obss, device, num_envs=1, argmax=False):
        self.preprocess_obss = preprocess_obss
        self.acmodel = acmodel
        self.device = device
        self.argmax = argmax
        self.num_envs = num_envs

        if self.acmodel.recurrent:
            # One memory slot per environment for recurrent models.
            self.memories = torch.zeros(self.num_envs, self.acmodel.memory_size, device=self.device)
    def get_actions(self, obss):
        preprocessed_obss = self.preprocess_obss(obss, device=self.device)

        with torch.no_grad():
            if self.acmodel.recurrent:
                dist, _, self.memories = self.acmodel(preprocessed_obss, self.memories)
            else:
                dist, _ = self.acmodel(preprocessed_obss)

        if isinstance(dist, torch.distributions.Distribution):
            # Single action head: take the most probable action or sample from the policy.
            if self.argmax:
                actions = dist.probs.max(1, keepdim=True)[1]
            else:
                actions = dist.sample()
        else:
            # Multiple action heads: one distribution per head, stacked along dim 1.
            if self.argmax:
                actions = torch.stack([d.probs.max(1)[1] for d in dist], dim=1)
            else:
                actions = torch.stack([d.sample() for d in dist], dim=1)

        return self.acmodel.construct_final_action(actions.cpu().numpy())
    def get_action(self, obs):
        return self.get_actions([obs])[0]

    def analyze_feedbacks(self, rewards, dones):
        if self.acmodel.recurrent:
            # Reset the recurrent memories of environments whose episode just ended.
            masks = 1 - torch.tensor(dones, dtype=torch.float, device=self.device).unsqueeze(1)
            self.memories *= masks

    def analyze_feedback(self, reward, done):
        return self.analyze_feedbacks([reward], [done])

class Tester:
    """Evaluates the model on a fixed set of test environments and records per-test statistics."""

    def __init__(self, env_args, seed, episodes, save_path, acmodel, preprocess_obss, device):
        self.envs = [utils.make_env(**env_args) for _ in range(episodes)]
        self.seed = seed
        self.episodes = episodes
        self.ep_counter = 0
        self.savefile = save_path + "/testing_{}.pkl".format(self.envs[0].spec.id)
        print("Testing log: ", self.savefile)
        self.stats_dict = {"test_rewards": [], "test_success_rates": [], "test_step_nb": []}
        self.agent = AgentWrap(acmodel, preprocess_obss, device)
    def test_agent(self, num_frames):
        self.agent.acmodel.eval()

        # Save the test time (number of training frames seen so far).
        self.stats_dict['test_step_nb'].append(num_frames)

        rewards = []
        success_rates = []

        for episode in range(self.episodes):
            # Seed each test environment with its index so every evaluation
            # runs on the same fixed set of episodes.
            self.envs[episode].seed(episode)
            obs = self.envs[episode].reset()
            done = False
            while not done:
                action = self.agent.get_action(obs)
                obs, reward, done, info = self.envs[episode].step(action)
                self.agent.analyze_feedback(reward, done)
                if done:
                    rewards.append(reward)
                    success_rates.append(info['success'])
                    break

        mean_rewards = np.array(rewards).mean()
        mean_success_rates = np.array(success_rates).mean()
        self.stats_dict["test_rewards"].append(mean_rewards)
        self.stats_dict["test_success_rates"].append(mean_success_rates)

        self.agent.acmodel.train()
        return mean_success_rates, mean_rewards
    def load(self):
        if os.path.isfile(self.savefile):
            with open(self.savefile, 'rb') as f:
                stats_dict_loaded = pickle.load(f)
            for k, v in stats_dict_loaded.items():
                self.stats_dict[k] = v
        else:
            raise ValueError(f"Save file {self.savefile} doesn't exist.")

    def dump(self):
        with open(self.savefile, 'wb') as f:
            pickle.dump(self.stats_dict, f)
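
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The dummy model below is a
# hypothetical stand-in, not part of this repository; it exists to show the
# interface AgentWrap expects from an actor-critic model: a `recurrent` flag,
# a forward pass returning (distribution, value), and construct_final_action.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from torch.distributions import Categorical

    class _DummyACModel(nn.Module):
        recurrent = False

        def __init__(self, n_actions=4):
            super().__init__()
            self.n_actions = n_actions

        def forward(self, obss):
            # Uniform policy over n_actions and a zero value estimate.
            logits = torch.zeros(len(obss), self.n_actions)
            return Categorical(logits=logits), torch.zeros(len(obss))

        def construct_final_action(self, actions):
            return actions

    agent = AgentWrap(
        acmodel=_DummyACModel(),
        preprocess_obss=lambda obss, device=None: obss,  # identity preprocessing
        device=torch.device("cpu"),
    )
    print("sampled action:", agent.get_action(np.zeros(3)))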