import time
from math import ceil

import torch

from src.drl.rep_mem import ReplayMem
from src.env.environments import AsyncOlGenEnv, SingleProcessOLGenEnv
from src.utils.filesys import getpath


class AsyncOffpolicyTrainer:
    """Off-policy trainer that collects transitions from an asynchronous
    online-generation environment and interleaves replay-buffer updates."""

    def __init__(self, rep_mem: ReplayMem = None, update_per=1, batch=256):
        self.rep_mem = ReplayMem() if rep_mem is None else rep_mem
        self.update_per = update_per    # environment transitions per gradient update
        self.batch = batch              # replay mini-batch size
        self.loggers = []
        self.start_time = 0.
        self.steps = 0
        self.num_trans = 0
        self.num_updates = 0

    def train(self, env: AsyncOlGenEnv, agent, budget, path, check_points=None):
        if check_points is None:
            check_points = []
        # Sort descending so the smallest pending checkpoint sits at the tail.
        check_points.sort(reverse=True)
        self._reset()
        o = env.reset()
        for logger in self.loggers:
            if logger.__class__.__name__ == 'GenResLogger':
                logger.on_episode(env, agent, 0)
        while self.steps < budget:
            agent.actor.eval()
            a = agent.make_decision(o)
            o, done = env.step(a)
            self.steps += 1
            if done:
                # Cap the updates spent on this episode so training cannot
                # lag arbitrarily far behind asynchronous data collection.
                model_credits = ceil(1.25 * env.eplen / self.update_per)
                self._update(model_credits, env, agent)
            if len(check_points) and self.steps >= check_points[-1]:
                torch.save(agent.actor.net, getpath(f'{path}/policy{self.steps}.pth'))
                check_points.pop()
        # Flush any remaining transitions, then save the final policy.
        self._update(0, env, agent, close=True)
        torch.save(agent.actor.net, getpath(f'{path}/policy.pth'))

    def _update(self, model_credits, env, agent, close=False):
        transitions, rewss = env.close() if close else env.rollout()
        self.rep_mem.add_transitions(transitions)
        self.num_trans += len(transitions)
        if len(self.rep_mem) > self.batch:
            t = 1
            # One update per `update_per` collected transitions, spending at
            # most `model_credits` updates unless this is the closing call.
            while self.num_trans >= (self.num_updates + 1) * self.update_per:
                agent.update(*self.rep_mem.sample(self.batch))
                self.num_updates += 1
                if not close and t == model_credits:
                    break
                t += 1
        for logger in self.loggers:
            loginfo = self._pack_loginfo(rewss)
            if logger.__class__.__name__ == 'GenResLogger':
                logger.on_episode(env, agent, self.steps)
            else:
                logger.on_episode(**loginfo, close=close)

    def _pack_loginfo(self, rewss):
        return {
            'steps': self.steps,
            'time': time.time() - self.start_time,
            'rewss': rewss,
            'trans': self.num_trans,
            'updates': self.num_updates
        }

    def _reset(self):
        self.start_time = time.time()
        self.steps = 0
        self.num_trans = 0
        self.num_updates = 0

    def set_loggers(self, *loggers):
        self.loggers = loggers
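
# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal sketch of how AsyncOffpolicyTrainer might be driven. `build_env`,
# `build_agent`, and `CsvLogger` are hypothetical stand-ins for project-specific
# factories and loggers; only the trainer calls themselves come from this module.
#
#   env = build_env()                      # an AsyncOlGenEnv instance
#   agent = build_agent()                  # exposes .actor, .make_decision, .update
#   trainer = AsyncOffpolicyTrainer(update_per=1, batch=256)
#   trainer.set_loggers(CsvLogger('runs/exp1/log.csv'))
#   trainer.train(env, agent, budget=100_000, path='runs/exp1',
#                 check_points=[25_000, 50_000, 75_000])
# ------------------------------------------------------------------------------
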
class SinProcOffpolicyTrainer:
    """Single-process variant: transitions arrive through the `info` dict
    returned by the environment rather than through asynchronous rollouts."""

    def __init__(self, rep_mem: ReplayMem = None, update_per=2, batch=256):
        self.rep_mem = ReplayMem() if rep_mem is None else rep_mem
        self.update_per = update_per
        self.batch = batch
        self.loggers = []
        self.start_time = 0.
        self.steps = 0
        self.num_trans = 0
        self.num_updates = 0

    def train(self, env: SingleProcessOLGenEnv, agent, budget, path):
        self.__reset()
        o = env.reset()
        for logger in self.loggers:
            if logger.__class__.__name__ == 'GenResLogger':
                logger.on_episode(env, agent, 0)
        while self.steps < budget:
            agent.actor.eval()
            a = agent.make_decision(o)
            o, _, done, info = env.step(a)
            self.steps += 1
            if done:
                self.__update(env, agent, info)
        # Final bookkeeping pass with no new transitions, then save the policy.
        self.__update(env, agent, {'transitions': [], 'rewss': []}, True)
        torch.save(agent.actor.net, getpath(f'{path}/policy.pth'))

    def __update(self, env, agent, info, close=False):
        transitions, rewss = info['transitions'], info['rewss']
        self.rep_mem.add_transitions(transitions)
        self.num_trans += len(transitions)
        if len(self.rep_mem) > self.batch:
            # One update per `update_per` collected transitions.
            while self.num_trans >= (self.num_updates + 1) * self.update_per:
                agent.update(*self.rep_mem.sample(self.batch))
                self.num_updates += 1
        for logger in self.loggers:
            loginfo = self.__pack_loginfo(rewss)
            if logger.__class__.__name__ == 'GenResLogger':
                logger.on_episode(env, agent, self.steps)
            else:
                logger.on_episode(**loginfo, close=close)

    def __pack_loginfo(self, rewss):
        if len(rewss):
            # Wrap the single episode's rewards so the log format matches the
            # list-of-episodes layout used by AsyncOffpolicyTrainer.
            rewss = [rewss]
        return {
            'steps': self.steps,
            'time': time.time() - self.start_time,
            'rewss': rewss,
            'trans': self.num_trans,
            'updates': self.num_updates
        }

    def __reset(self):
        self.start_time = time.time()
        self.steps = 0
        self.num_trans = 0
        self.num_updates = 0

    def set_loggers(self, *loggers):
        self.loggers = loggers
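
# --- Usage sketch (illustrative only) ----------------------------------------
# The single-process variant follows the same pattern; `build_env` and
# `build_agent` remain hypothetical. As the train loop above shows, the
# environment's step() is expected to return a Gym-style (obs, reward, done,
# info) 4-tuple, with the finished episode's transitions and rewards delivered
# through `info` on termination.
#
#   env = build_env()                      # a SingleProcessOLGenEnv instance
#   agent = build_agent()
#   trainer = SinProcOffpolicyTrainer(update_per=2, batch=256)
#   trainer.train(env, agent, budget=50_000, path='runs/exp2')
# ------------------------------------------------------------------------------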