from collections import deque, OrderedDict
from src.rlkit.core.eval_util import create_stats_ordered_dict
from src.rlkit.samplers.rollout_functions import (
    rollout,
    multitask_rollout,
    ensemble_rollout,
    ensemble_eval_rollout,
    ensemble_ucb_rollout,
    # normalized_rollout is called by collect_normalized_new_paths below but was
    # never imported; it is assumed to live alongside the other rollout helpers.
    normalized_rollout,
)
from src.rlkit.samplers.data_collector.base import PathCollector


class MdpPathCollector(PathCollector):
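    """Collects rollouts of a single policy in an MDP environment.

    Paths gathered during the current epoch are kept in a bounded deque
    (`max_num_epoch_paths_saved`) so diagnostics reflect recent data only.
    """
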
def __init__(
self,
env,
policy,
noise_flag=0,
max_num_epoch_paths_saved=None,
render=False,
render_kwargs=None,
):
if render_kwargs is None:
render_kwargs = {}
self._env = env
self._policy = policy
self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
self._render = render
self._render_kwargs = render_kwargs
        # Forwarded to the rollout functions; presumably selects the
        # exploration-noise behavior during action selection.
        self._noise_flag = noise_flag
self._num_steps_total = 0
self._num_paths_total = 0

    def collect_new_paths(
self,
max_path_length,
num_steps,
discard_incomplete_paths,
):
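        """Roll out the policy until roughly `num_steps` env steps are collected.

        Each rollout is capped so the running total never exceeds `num_steps`.
        If `discard_incomplete_paths` is set, a final path that was cut short
        by the step budget (rather than ending at a terminal) is dropped.
        """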
paths = []
num_steps_collected = 0
while num_steps_collected < num_steps:
max_path_length_this_loop = min( # Do not go over num_steps
max_path_length,
num_steps - num_steps_collected,
)
path = rollout(
self._env,
self._policy,
noise_flag=self._noise_flag,
max_path_length=max_path_length_this_loop,
)
path_len = len(path['actions'])
if (
path_len != max_path_length
and not path['terminals'][-1]
and discard_incomplete_paths
):
break
num_steps_collected += path_len
paths.append(path)
self._num_paths_total += len(paths)
self._num_steps_total += num_steps_collected
self._epoch_paths.extend(paths)
return paths

    def collect_normalized_new_paths(
self,
max_path_length,
num_steps,
discard_incomplete_paths,
input_mean,
input_std,
):
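        """Like `collect_new_paths`, but rolls out via `normalized_rollout`,
        which presumably standardizes observations with `input_mean` and
        `input_std` before querying the policy.
        """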
paths = []
num_steps_collected = 0
while num_steps_collected < num_steps:
max_path_length_this_loop = min( # Do not go over num_steps
max_path_length,
num_steps - num_steps_collected,
)
path = normalized_rollout(
self._env,
self._policy,
noise_flag=self._noise_flag,
max_path_length=max_path_length_this_loop,
input_mean=input_mean,
input_std=input_std,
)
path_len = len(path['actions'])
if (
path_len != max_path_length
and not path['terminals'][-1]
and discard_incomplete_paths
):
break
num_steps_collected += path_len
paths.append(path)
self._num_paths_total += len(paths)
self._num_steps_total += num_steps_collected
self._epoch_paths.extend(paths)
return paths

    def get_epoch_paths(self):
        return self._epoch_paths

    def end_epoch(self, epoch):
        self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)

    def get_diagnostics(self):
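        """Report running totals plus summary statistics of epoch path lengths."""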
path_lens = [len(path['actions']) for path in self._epoch_paths]
stats = OrderedDict([
('num steps total', self._num_steps_total),
('num paths total', self._num_paths_total),
])
stats.update(create_stats_ordered_dict(
"path length",
path_lens,
always_show_all_stats=True,
))
return stats

    def get_snapshot(self):
return dict(
env=self._env,
policy=self._policy,
)


class EnsembleMdpPathCollector(PathCollector):
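    """Path collector for an ensemble of policies.

    Three collection modes are dispatched in `collect_new_paths`: evaluation
    rollouts (`eval_flag`), UCB rollouts driven by the critic pair when
    `inference_type` > 0, and plain ensemble rollouts otherwise. `ber_mean`
    is passed through to the rollout functions; it presumably parameterizes
    the Bernoulli masks used for bootstrapped ensemble training.
    """
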
def __init__(
self,
env,
policy,
num_ensemble,
noise_flag=0,
ber_mean=0.5,
eval_flag=False,
max_num_epoch_paths_saved=None,
render=False,
render_kwargs=None,
critic1=None,
critic2=None,
inference_type=0.0,
feedback_type=1,
):
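        # critic1/critic2 back the UCB action selection; inference_type > 0
        # enables it and presumably scales the uncertainty bonus, while
        # feedback_type is forwarded as-is to ensemble_ucb_rollout.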
if render_kwargs is None:
render_kwargs = {}
self._env = env
self._policy = policy
self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
self._render = render
self._render_kwargs = render_kwargs
self.num_ensemble = num_ensemble
self.eval_flag = eval_flag
self.ber_mean = ber_mean
self.critic1 = critic1
self.critic2 = critic2
self.inference_type = inference_type
self.feedback_type = feedback_type
self._noise_flag = noise_flag
self._num_steps_total = 0
self._num_paths_total = 0

    def collect_new_paths(
self,
max_path_length,
num_steps,
discard_incomplete_paths,
):
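        """Collect paths, dispatching on the collector's mode: evaluation
        rollouts when `eval_flag` is set, UCB rollouts when
        `inference_type` > 0, and bootstrapped ensemble rollouts otherwise.
        """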
paths = []
num_steps_collected = 0
while num_steps_collected < num_steps:
max_path_length_this_loop = min( # Do not go over num_steps
max_path_length,
num_steps - num_steps_collected,
)
if self.eval_flag:
path = ensemble_eval_rollout(
self._env,
self._policy,
self.num_ensemble,
max_path_length=max_path_length_this_loop,
)
else:
if self.inference_type > 0: # UCB
path = ensemble_ucb_rollout(
self._env,
self._policy,
critic1=self.critic1,
critic2=self.critic2,
inference_type=self.inference_type,
feedback_type=self.feedback_type,
num_ensemble=self.num_ensemble,
noise_flag=self._noise_flag,
max_path_length=max_path_length_this_loop,
ber_mean=self.ber_mean,
)
else:
path = ensemble_rollout(
self._env,
self._policy,
self.num_ensemble,
noise_flag=self._noise_flag,
max_path_length=max_path_length_this_loop,
ber_mean=self.ber_mean,
)
path_len = len(path['actions'])
if (
path_len != max_path_length
and not path['terminals'][-1]
and discard_incomplete_paths
):
break
num_steps_collected += path_len
paths.append(path)
self._num_paths_total += len(paths)
self._num_steps_total += num_steps_collected
self._epoch_paths.extend(paths)
return paths

    def get_epoch_paths(self):
        return self._epoch_paths

    def end_epoch(self, epoch):
        self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)

    def get_diagnostics(self):
path_lens = [len(path['actions']) for path in self._epoch_paths]
stats = OrderedDict([
('num steps total', self._num_steps_total),
('num paths total', self._num_paths_total),
])
stats.update(create_stats_ordered_dict(
"path length",
path_lens,
always_show_all_stats=True,
))
return stats

    def get_snapshot(self):
return dict(
env=self._env,
policy=self._policy,
)


class AsyncEnsembleMdpPathCollector(PathCollector):
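    """Asynchronous variant of `EnsembleMdpPathCollector`.

    The collection logic below mirrors the synchronous class; any async
    scheduling is presumably handled by the caller or elsewhere in the file.
    """
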
def __init__(
self,
env,
policy,
num_ensemble,
noise_flag=0,
ber_mean=0.5,
eval_flag=False,
max_num_epoch_paths_saved=None,
render=False,
render_kwargs=None,
critic1=None,
critic2=None,
inference_type=0.0,
feedback_type=1,
):
if render_kwargs is None:
render_kwargs = {}
self._env = env
self._policy = policy
self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
self._render = render
self._render_kwargs = render_kwargs
self.num_ensemble = num_ensemble
self.eval_flag = eval_flag
self.ber_mean = ber_mean
self.critic1 = critic1
self.critic2 = critic2
self.inference_type = inference_type
self.feedback_type = feedback_type
self._noise_flag = noise_flag
self._num_steps_total = 0
self._num_paths_total = 0

    def collect_new_paths(
self,
max_path_length,
num_steps,
discard_incomplete_paths,
):
paths = []
num_steps_collected = 0
while num_steps_collected < num_steps:
max_path_length_this_loop = min( # Do not go over num_steps
max_path_length,
num_steps - num_steps_collected,
)
if self.eval_flag:
path = ensemble_eval_rollout(
self._env,
self._policy,
self.num_ensemble,
max_path_length=max_path_length_this_loop,
)
else:
if self.inference_type > 0: # UCB
path = ensemble_ucb_rollout(
self._env,
self._policy,
critic1=self.critic1,
critic2=self.critic2,
inference_type=self.inference_type,
feedback_type=self.feedback_type,
num_ensemble=self.num_ensemble,
noise_flag=self._noise_flag,
max_path_length=max_path_length_this_loop,
ber_mean=self.ber_mean,
)
else:
path = ensemble_rollout(
self._env,
self._policy,
self.num_ensemble,
noise_flag=self._noise_flag,
max_path_length=max_path_length_this_loop,
ber_mean=self.ber_mean,
)
path_len = len(path['actions'])
if (
path_len != max_path_length
and not path['terminals'][-1]
and discard_incomplete_paths
):
break
num_steps_collected += path_len
paths.append(path)
self._num_paths_total += len(paths)
self._num_steps_total += num_steps_collected
self._epoch_paths.extend(paths)
return paths