import random
import time
from math import ceil

import numpy as np
import torch

from src.drl.rep_mem import ReplayMem
from src.env.environments import AsyncOlGenEnv
from src.rlkit.samplers.rollout_functions import get_ucb_std
from src.rlkit.torch.sac.neurips20_sac_ensemble import NeurIPS20SACEnsembleTrainer
from src.utils.datastruct import RingQueue
from src.utils.filesys import getpath


class AsyncOffPolicyALgo:
    def __init__(self, rep_mem=None, update_per=1, batch=256, device='cpu'):
        self.rep_mem = rep_mem
        self.update_per = update_per
        self.batch = batch
        self.loggers = []
        self.device = device
        self.start_time = 0.0
        self.steps = 0
        self.num_trans = 0
        self.num_updates = 0
        self.env = None
        self.trainer = None
        self.proxy_agent = None

    def train(self, env: AsyncOlGenEnv, trainer: NeurIPS20SACEnsembleTrainer, budget, inference_type, path):
        assert trainer.device == self.device
        self.__reset(env, trainer)
        env.reset()
        actors, critic1, critic2 = trainer.policy, trainer.qf1, trainer.qf2
        # Log generation results for the untrained policy at step 0.
        for logger in self.loggers:
            if logger.__class__.__name__ == 'GenResLogger':
                logger.on_episode(env, self.proxy_agent, 0)
        while self.steps < budget:
            if inference_type > 0.:
                self.ucb_interaction(inference_type, critic1, critic2)
            else:
                self.rand_choice_interaction()
            # Cap the number of gradient updates allowed for this episode.
            update_credits = ceil(1.25 * env.eplen / self.update_per)
            self.__update(update_credits)
        self.__update(0, close=True)
        for i, actor in enumerate(actors):
            torch.save(actor, getpath(path, f'policy{i}.pth'))

    def __update(self, model_credits, close=False):
        transitions, rewss = self.env.close() if close else self.env.rollout()
        self.rep_mem.add_transitions(transitions)
        self.num_trans += len(transitions)
        if len(self.rep_mem) > self.batch:
            # Train on sampled minibatches until the update count catches up with
            # the collected transitions, capped at model_credits unless closing.
            t = 1
            while self.num_trans >= (self.num_updates + 1) * self.update_per:
                self.trainer.train_from_torch(self.rep_mem.sample(self.batch))
                self.num_updates += 1
                if not close and t == model_credits:
                    break
                t += 1
        loginfo = self.__pack_loginfo(rewss)
        for logger in self.loggers:
            if logger.__class__.__name__ == 'GenResLogger':
                logger.on_episode(self.env, self.proxy_agent, self.steps)
            else:
                logger.on_episode(**loginfo, close=close)

    def __pack_loginfo(self, rewss):
        return {
            'steps': self.steps,
            'time': time.time() - self.start_time,
            'rewss': rewss,
            'trans': self.num_trans,
            'updates': self.num_updates
        }

    def __reset(self, env: AsyncOlGenEnv, trainer: NeurIPS20SACEnsembleTrainer):
        self.start_time = time.time()
        self.steps = 0
        self.num_trans = 0
        self.num_updates = 0
        self.env = env
        self.trainer = trainer
        self.proxy_agent = SunriseProxyAgent(trainer.policy, self.device)
        if self.rep_mem is None:
            self.rep_mem = MaskedRepMem(trainer.num_ensemble)
        assert self.rep_mem.m == trainer.num_ensemble

    def set_loggers(self, *loggers):
        self.loggers = loggers

    def ucb_interaction(self, inference_type, critic1, critic2, feedback_type=1):
        """
        Adapted from rlkit.samplers.rollout_functions.ensemble_ucb_rollout.
        Mask generation is moved to the replay memory (MaskedRepMem).
        The noise flag is ignored since it is not useful for our experiments.
        feedback_type is fixed to 1, as the original code never changes it.
        """
        o = self.env.getobs()
        policy = self.trainer.policy
        for subpolicy in policy:
            subpolicy.reset()
        while True:
            # Pick the candidate action with the highest UCB score over the ensemble.
            a_max, ucb_max = None, None
            for i, subpolicy in enumerate(policy):
                _a, _ = subpolicy.get_action(o)
                ucb_score = get_ucb_std(
                    o, _a, inference_type, critic1, critic2, feedback_type, i, len(policy)
                )
                if i == 0 or ucb_score > ucb_max:
                    a_max, ucb_max = _a, ucb_score
            o, d = self.env.step(a_max.squeeze())
            self.steps += 1
            if d:
                break

    def rand_choice_interaction(self):
        o = self.env.getobs()
        policy = self.trainer.policy
        for subpolicy in policy:
            subpolicy.reset()
        # Roll out with a single randomly chosen ensemble member until the
        # environment signals termination.
        chosen_policy = random.choice(policy)
        while True:
            _a, _ = chosen_policy.get_action(o)
            o, d = self.env.step(_a.squeeze())
            self.steps += 1
            if d:
                break


class SunriseProxyAgent:
    # Refer to rlkit.samplers.rollout_functions.ensemble_eval_rollout
    def __init__(self, actors, device):
        self.actors = actors
        self.device = device

    def make_decision(self, obs):
        o = torch.tensor(obs, device=self.device, dtype=torch.float32)
        actions = []
        with torch.no_grad():
            for m in self.actors:
                a = torch.clamp(m(o)[0], -1, 1)
                actions.append(a.cpu().numpy())
        actions = np.array(actions)
        # For each observation in the batch, act with a randomly selected ensemble member.
        selections = [random.choice(range(len(self.actors))) for _ in range(len(obs))]
        selected = [actions[s, i, :] for i, s in enumerate(selections)]
        return np.array(selected)
        # if len(obs.shape) == 1:
        #     obs = obs.reshape(1, -1)
        # with torch.no_grad():
        #     actions = np.stack([actor.get_action(obs)[0].squeeze() for actor in self.actors])
        # return actions.mean(axis=0)

    def reset(self):
        for actor in self.actors:
            actor.reset()


class MaskedRepMem:
    def __init__(self, num_ensemble, capacity=500000, ber_mean=0.0, device='cpu'):
        self.base = ReplayMem(capacity, device)
        self.mask_queue = RingQueue(capacity)
        self.ber_mean = ber_mean
        self.device = device
        self.m = num_ensemble

    def add(self, o, a, r, op):
        # Bootstrap mask: decide which ensemble members train on this transition.
        mask = torch.bernoulli(torch.Tensor([self.ber_mean] * self.m))
        if mask.sum() == 0:
            # Guarantee that at least one ensemble member sees every transition.
            rand_index = np.random.randint(self.m, size=1)
            mask[rand_index] = 1
        self.mask_queue.push(mask.to(self.base.device))
        self.base.add(o, a, r, op)

    def add_transitions(self, transitions):
        for t in transitions:
            self.add(*t)

    def __len__(self):
        return len(self.base)

    def sample(self, n):
        indexes = random.sample(range(len(self.base)), n)
        base_mem = self.base.queue.main
        mask_mem = self.mask_queue.main
        obs, acts, rews, ops, masks = [], [], [], [], []
        for i in indexes:
            o, a, r, op = base_mem[i]
            obs.append(o)
            acts.append(a)
            rews.append([r])
            ops.append(op)
            masks.append(mask_mem[i])
        return {
            'observations': torch.stack(obs),
            'actions': torch.stack(acts),
            'rewards': torch.tensor(rews, device=self.device, dtype=torch.float),
            'next_observations': torch.stack(ops),
            'masks': torch.stack(masks)
        }