from collections import deque, OrderedDict

from src.rlkit.core.eval_util import create_stats_ordered_dict
from src.rlkit.samplers.rollout_functions import (
    rollout,
    normalized_rollout,  # used by MdpPathCollector.collect_normalized_new_paths
    multitask_rollout,
    ensemble_rollout,
    ensemble_eval_rollout,
    ensemble_ucb_rollout,
)
from src.rlkit.samplers.data_collector.base import PathCollector


class MdpPathCollector(PathCollector):
    """Collects rollouts from a single environment with a single policy."""

    def __init__(
            self,
            env,
            policy,
            noise_flag=0,
            max_num_epoch_paths_saved=None,
            render=False,
            render_kwargs=None,
    ):
        if render_kwargs is None:
            render_kwargs = {}
        self._env = env
        self._policy = policy
        self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
        self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
        self._render = render
        self._render_kwargs = render_kwargs
        self._noise_flag = noise_flag

        self._num_steps_total = 0
        self._num_paths_total = 0

    def collect_new_paths(
            self,
            max_path_length,
            num_steps,
            discard_incomplete_paths,
    ):
        paths = []
        num_steps_collected = 0
        while num_steps_collected < num_steps:
            max_path_length_this_loop = min(  # Do not go over num_steps
                max_path_length,
                num_steps - num_steps_collected,
            )
            path = rollout(
                self._env,
                self._policy,
                noise_flag=self._noise_flag,
                max_path_length=max_path_length_this_loop,
            )
            path_len = len(path['actions'])
            if (
                    path_len != max_path_length
                    and not path['terminals'][-1]
                    and discard_incomplete_paths
            ):
                break
            num_steps_collected += path_len
            paths.append(path)
        self._num_paths_total += len(paths)
        self._num_steps_total += num_steps_collected
        self._epoch_paths.extend(paths)
        return paths

    def collect_normalized_new_paths(
            self,
            max_path_length,
            num_steps,
            discard_incomplete_paths,
            input_mean,
            input_std,
    ):
        """Like collect_new_paths, but uses normalized_rollout with the given
        input_mean and input_std."""
        paths = []
        num_steps_collected = 0
        while num_steps_collected < num_steps:
            max_path_length_this_loop = min(  # Do not go over num_steps
                max_path_length,
                num_steps - num_steps_collected,
            )
            path = normalized_rollout(
                self._env,
                self._policy,
                noise_flag=self._noise_flag,
                max_path_length=max_path_length_this_loop,
                input_mean=input_mean,
                input_std=input_std,
            )
            path_len = len(path['actions'])
            if (
                    path_len != max_path_length
                    and not path['terminals'][-1]
                    and discard_incomplete_paths
            ):
                break
            num_steps_collected += path_len
            paths.append(path)
        self._num_paths_total += len(paths)
        self._num_steps_total += num_steps_collected
        self._epoch_paths.extend(paths)
        return paths

    def get_epoch_paths(self):
        return self._epoch_paths

    def end_epoch(self, epoch):
        self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)

    def get_diagnostics(self):
        path_lens = [len(path['actions']) for path in self._epoch_paths]
        stats = OrderedDict([
            ('num steps total', self._num_steps_total),
            ('num paths total', self._num_paths_total),
        ])
        stats.update(create_stats_ordered_dict(
            "path length",
            path_lens,
            always_show_all_stats=True,
        ))
        return stats

    def get_snapshot(self):
        return dict(
            env=self._env,
            policy=self._policy,
        )
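
# Usage sketch (illustrative only, not part of the original module): a collector
# is built around an env/policy pair and queried once per training epoch.
# `expl_env`, `expl_policy`, and the step budgets below are hypothetical placeholders.
#
#     collector = MdpPathCollector(expl_env, expl_policy, noise_flag=1)
#     paths = collector.collect_new_paths(
#         max_path_length=1000,
#         num_steps=4000,
#         discard_incomplete_paths=False,
#     )
#     print(collector.get_diagnostics())
#     collector.end_epoch(epoch=0)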


class EnsembleMdpPathCollector(PathCollector):
    """Path collector for ensemble agents: paths come from ensemble_eval_rollout
    when eval_flag is set, ensemble_ucb_rollout when inference_type > 0 (UCB
    exploration), and ensemble_rollout otherwise."""

    def __init__(
            self,
            env,
            policy,
            num_ensemble,
            noise_flag=0,
            ber_mean=0.5,
            eval_flag=False,
            max_num_epoch_paths_saved=None,
            render=False,
            render_kwargs=None,
            critic1=None,
            critic2=None,
            inference_type=0.0,
            feedback_type=1,
    ):
        if render_kwargs is None:
            render_kwargs = {}
        self._env = env
        self._policy = policy
        self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
        self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
        self._render = render
        self._render_kwargs = render_kwargs
        self.num_ensemble = num_ensemble
        self.eval_flag = eval_flag
        self.ber_mean = ber_mean
        self.critic1 = critic1
        self.critic2 = critic2
        self.inference_type = inference_type
        self.feedback_type = feedback_type
        self._noise_flag = noise_flag

        self._num_steps_total = 0
        self._num_paths_total = 0

    def collect_new_paths(
            self,
            max_path_length,
            num_steps,
            discard_incomplete_paths,
    ):
        paths = []
        num_steps_collected = 0
        while num_steps_collected < num_steps:
            max_path_length_this_loop = min(  # Do not go over num_steps
                max_path_length,
                num_steps - num_steps_collected,
            )
            if self.eval_flag:
                path = ensemble_eval_rollout(
                    self._env,
                    self._policy,
                    self.num_ensemble,
                    max_path_length=max_path_length_this_loop,
                )
            else:
                if self.inference_type > 0:  # UCB
                    path = ensemble_ucb_rollout(
                        self._env,
                        self._policy,
                        critic1=self.critic1,
                        critic2=self.critic2,
                        inference_type=self.inference_type,
                        feedback_type=self.feedback_type,
                        num_ensemble=self.num_ensemble,
                        noise_flag=self._noise_flag,
                        max_path_length=max_path_length_this_loop,
                        ber_mean=self.ber_mean,
                    )
                else:
                    path = ensemble_rollout(
                        self._env,
                        self._policy,
                        self.num_ensemble,
                        noise_flag=self._noise_flag,
                        max_path_length=max_path_length_this_loop,
                        ber_mean=self.ber_mean,
                    )
            path_len = len(path['actions'])
            if (
                    path_len != max_path_length
                    and not path['terminals'][-1]
                    and discard_incomplete_paths
            ):
                break
            num_steps_collected += path_len
            paths.append(path)
        self._num_paths_total += len(paths)
        self._num_steps_total += num_steps_collected
        self._epoch_paths.extend(paths)
        return paths

    def get_epoch_paths(self):
        return self._epoch_paths

    def end_epoch(self, epoch):
        self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)

    def get_diagnostics(self):
        path_lens = [len(path['actions']) for path in self._epoch_paths]
        stats = OrderedDict([
            ('num steps total', self._num_steps_total),
            ('num paths total', self._num_paths_total),
        ])
        stats.update(create_stats_ordered_dict(
            "path length",
            path_lens,
            always_show_all_stats=True,
        ))
        return stats

    def get_snapshot(self):
        return dict(
            env=self._env,
            policy=self._policy,
        )
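
# Usage sketch (illustrative only): with inference_type > 0 and both critics
# supplied, collect_new_paths routes through ensemble_ucb_rollout; with
# eval_flag=True it uses ensemble_eval_rollout instead. `expl_env`,
# `ensemble_policy`, `qf1`, and `qf2` are hypothetical placeholders.
#
#     collector = EnsembleMdpPathCollector(
#         expl_env,
#         ensemble_policy,
#         num_ensemble=5,
#         critic1=qf1,
#         critic2=qf2,
#         inference_type=1.0,  # > 0 selects the UCB branch
#         feedback_type=1,
#     )
#     paths = collector.collect_new_paths(
#         max_path_length=1000,
#         num_steps=4000,
#         discard_incomplete_paths=False,
#     )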


class AsyncEnsembleMdpPathCollector(PathCollector):
    """Async variant of EnsembleMdpPathCollector; its constructor and
    collect_new_paths mirror the collector above."""

    def __init__(
            self,
            env,
            policy,
            num_ensemble,
            noise_flag=0,
            ber_mean=0.5,
            eval_flag=False,
            max_num_epoch_paths_saved=None,
            render=False,
            render_kwargs=None,
            critic1=None,
            critic2=None,
            inference_type=0.0,
            feedback_type=1,
    ):
        if render_kwargs is None:
            render_kwargs = {}
        self._env = env
        self._policy = policy
        self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
        self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
        self._render = render
        self._render_kwargs = render_kwargs
        self.num_ensemble = num_ensemble
        self.eval_flag = eval_flag
        self.ber_mean = ber_mean
        self.critic1 = critic1
        self.critic2 = critic2
        self.inference_type = inference_type
        self.feedback_type = feedback_type
        self._noise_flag = noise_flag

        self._num_steps_total = 0
        self._num_paths_total = 0

    def collect_new_paths(
            self,
            max_path_length,
            num_steps,
            discard_incomplete_paths,
    ):
        paths = []
        num_steps_collected = 0
        while num_steps_collected < num_steps:
            max_path_length_this_loop = min(  # Do not go over num_steps
                max_path_length,
                num_steps - num_steps_collected,
            )
            if self.eval_flag:
                path = ensemble_eval_rollout(
                    self._env,
                    self._policy,
                    self.num_ensemble,
                    max_path_length=max_path_length_this_loop,
                )
            else:
                if self.inference_type > 0:  # UCB
                    path = ensemble_ucb_rollout(
                        self._env,
                        self._policy,
                        critic1=self.critic1,
                        critic2=self.critic2,
                        inference_type=self.inference_type,
                        feedback_type=self.feedback_type,
                        num_ensemble=self.num_ensemble,
                        noise_flag=self._noise_flag,
                        max_path_length=max_path_length_this_loop,
                        ber_mean=self.ber_mean,
                    )
                else:
                    path = ensemble_rollout(
                        self._env,
                        self._policy,
                        self.num_ensemble,
                        noise_flag=self._noise_flag,
                        max_path_length=max_path_length_this_loop,
                        ber_mean=self.ber_mean,
                    )
            path_len = len(path['actions'])
            if (
                    path_len != max_path_length
                    and not path['terminals'][-1]
                    and discard_incomplete_paths
            ):
                break
            num_steps_collected += path_len
            paths.append(path)
        self._num_paths_total += len(paths)
        self._num_steps_total += num_steps_collected
        self._epoch_paths.extend(paths)
        return paths