NCERL-Diverse-PCG / src /olgen /olg_policy.py
baiyanlali-zhao's picture
添加注释
3582c8a
import glob
import random
from abc import abstractmethod
import numpy as np
import torch
from src.utils.filesys import getpath
from src.gan.gans import nz
from src.gan.gankits import sample_latvec
from src.drl.drl_uses import load_cfgs
from src.drl.sunrise.sunrise_adaption import SunriseProxyAgent
def process_obs(obs, device='cpu'):
obs = torch.tensor(obs, device=device, dtype=torch.float32)
if len(obs.shape) == 1:
obs = obs.unsqueeze(0)
return obs
class GenPolicy:
def __init__(self, n=5):
self.n = n # Number of segments in an observation
@abstractmethod
def step(self, obs):
pass
@staticmethod
@abstractmethod
def from_path(path, **kwargs):
pass
def reset(self):
pass
class RLGenPolicy(GenPolicy):
def __init__(self, model, n, device='cuda:0'):
self.model = model
super(RLGenPolicy, self).__init__(n)
self.model.eval()
self.model.to(device)
self.device = device
self.meregs = []
def step(self, obs):
obs = process_obs(obs, device=self.device)
b, d = obs.shape
if d < nz * self.n:
obs = torch.cat([torch.zeros([b, nz * self.n - d], device=self.device), obs], dim=-1)
with torch.no_grad():
model_output, _ = self.model(obs)
return torch.clamp(model_output, -1, 1).squeeze().cpu().numpy()
@staticmethod
def from_path(path, device='cuda:0'):
model = torch.load(getpath(f'{path}/policy.pth'), map_location=device)
n = load_cfgs(path, 'N')
return RLGenPolicy(model, n, device)
class EnsembleGenPolicy(GenPolicy):
def __init__(self, models, n, device='cpu'):
super(EnsembleGenPolicy, self).__init__(n)
for model in models:
model.to(device)
self.device = device
self.models = models
self.m = len(models)
def step(self, obs):
o = torch.tensor(obs, device=self.device, dtype=torch.float32)
actions = []
with torch.no_grad():
for m in self.models:
a = m(o) # action model predict
if type(a) == tuple:
a = a[0]
actions.append(torch.clamp(a, -1, 1).cpu().numpy())
if len(obs.shape) == 1:
return random.choice(actions)
else:
# 这里对于每个observation, 选择m个模型, 每个模型都输出一个动作, 然后随机选择其中一个动作
actions = np.array(actions)
# 这里的self.m就是模型的数量, 等价于len(self.models)
selections = [random.choice(range(self.m)) for _ in range(len(obs))]
selected = [actions[s, i, :] for i, s in enumerate(selections)]
return np.array(selected)
@staticmethod
def from_path(path, device='cpu'):
"""
读取path中的所有模型
"""
models = [
torch.load(p, map_location=device)
for p in glob.glob(getpath(path, 'policy*.pth'))
]
n = load_cfgs(path, 'N')
return EnsembleGenPolicy(models, n, device)
class RandGenPolicy(GenPolicy):
def __init__(self):
super(RandGenPolicy, self).__init__(1)
def step(self, obs):
n = obs.shape[0]
return sample_latvec(n).squeeze().numpy()
@staticmethod
def from_path(path, **kwargs):
return RandGenPolicy()
class DvDGenPolicy(GenPolicy):
def __init__(self, learner, n, rand_switch=False):
super(DvDGenPolicy, self).__init__(n)
self.master = learner
self.rand_switch = rand_switch
self.working_policy = None
self.reset()
def step(self, obs):
return self.working_policy.forward(obs).astype(np.float32)
def reset(self):
if self.rand_switch:
self.working_policy = random.choice(self.master.agents)
else:
self.working_policy = self.master.agents[self.master.agent]
@staticmethod
def from_path(path, device='cpu'):
"""
We don't find loading function in DvD-ES codes and we have no idea about
how to implement it :-(
"""
return None