Spaces:
Sleeping
Sleeping
import os | |
import os.path as osp | |
import time | |
import numpy as np | |
import scipy.misc | |
import skvideo.io | |
from rlkit.envs.vae_wrapper import VAEWrappedEnv | |
def dump_video(
        env,
        policy,
        filename,
        rollout_function,
        rows=3,
        columns=6,
        pad_length=0,
        pad_color=255,
        do_timer=True,
        horizon=100,
        dirname_to_save_images=None,
        subdirname="rollouts",
        imsize=84,
        num_channels=3,
):
    """Collect a grid of rollouts and write them to a single video file.

    ``rows * columns`` rollouts are gathered with ``rollout_function``. Each
    step of a rollout is rendered by ``get_image`` as a vertical strip of
    three panels (desired goal, observation, reconstruction). The strips of
    all rollouts are then tiled column-major (rollout ``k`` lands in grid
    column ``k // rows``, grid row ``k % rows``) and written with
    ``skvideo.io.vwrite``.

    Args:
        env: Environment handed to ``rollout_function``. When it is a
            ``VAEWrappedEnv`` the third panel is ``env._reconstruct_img`` of
            the observation, clipped to [0, 1]; otherwise the observation is
            shown again.
        policy: Policy handed to ``rollout_function``.
        filename: Output path of the video.
        rollout_function: Called as ``rollout_function(env, policy,
            max_path_length=horizon, render=False)``. Its result must provide
            ``'full_observations'``: an iterable of dicts with
            ``'image_desired_goal'`` and ``'image_observation'`` entries.
        rows: Number of rollouts stacked vertically in the grid.
        columns: Number of rollouts placed side by side in the grid.
        pad_length: Border width in pixels around each strip (0 = none).
        pad_color: Border value, forwarded to ``get_image``.
        do_timer: If true, print the rollout index and its wall-clock time.
        horizon: Forwarded to ``rollout_function`` as ``max_path_length``.
        dirname_to_save_images: If truthy, per-rollout PNGs are written to
            ``<dirname_to_save_images>/<subdirname>/<rollout index>/``.
        subdirname: Sub-directory name used for the per-rollout PNGs.
        imsize: Side length of one square panel in pixels.
        num_channels: Channel count used when reshaping the frame array.

    Note:
        All rollouts must produce the same number of frames, otherwise the
        final reshape into ``(N, path_length, ...)`` cannot succeed.
    """
    frames = []
    H = 3 * imsize
    W = imsize
    N = rows * columns
    padded_h = H + 2 * pad_length
    padded_w = W + 2 * pad_length
    # The environment type cannot change between rollouts; check it once.
    is_vae_env = isinstance(env, VAEWrappedEnv)

    for i in range(N):
        start = time.time()
        path = rollout_function(
            env,
            policy,
            max_path_length=horizon,
            render=False,
        )
        rollout_frames = []
        for d in path['full_observations']:
            if is_vae_env:
                recon = np.clip(
                    env._reconstruct_img(d['image_observation']), 0, 1
                )
            else:
                recon = d['image_observation']
            rollout_frames.append(
                get_image(
                    d['image_desired_goal'],
                    d['image_observation'],
                    recon,
                    pad_length=pad_length,
                    pad_color=pad_color,
                    imsize=imsize,
                )
            )
        frames += rollout_frames

        if dirname_to_save_images:
            rollout_dir = osp.join(dirname_to_save_images, subdirname, str(i))
            os.makedirs(rollout_dir, exist_ok=True)
            # Only this rollout's frames are dumped. The previous version
            # sliced the last 101 entries of the global frame list, which is
            # only right when every rollout has exactly 101 frames.
            # NOTE(review): scipy.misc.imsave no longer exists in current
            # SciPy releases; switching to another writer needs a new
            # dependency and is left to the maintainers.
            # NOTE: the crops below are not offset by pad_length, so with a
            # border they include border pixels.
            if rollout_frames:
                # Top panel of the first frame is the desired goal; it is
                # flipped vertically before saving.
                goal_img = np.flip(rollout_frames[0][:imsize, :imsize, :], 0)
                scipy.misc.imsave(osp.join(rollout_dir, "goal.png"), goal_img)
            if len(rollout_frames) > 1:
                goal_img = np.flip(rollout_frames[1][:imsize, :imsize, :], 0)
                scipy.misc.imsave(
                    osp.join(rollout_dir, "z_goal.png"), goal_img
                )
            for j, frame in enumerate(rollout_frames):
                # Everything below the goal panel: observation + recon.
                img = np.flip(frame[imsize:, :imsize, :], 0)
                scipy.misc.imsave(osp.join(rollout_dir, str(j) + ".png"), img)
        if do_timer:
            print(i, time.time() - start)

    frames = np.array(frames, dtype=np.uint8)
    path_length = frames.size // (N * padded_h * padded_w * num_channels)
    frames = frames.reshape(
        (N, path_length, padded_h, padded_w, num_channels)
    )

    grid_columns = []
    for k1 in range(columns):
        # Stack this grid column's rollouts along the height axis.
        column_cells = [frames[k1 * rows + k2] for k2 in range(rows)]
        grid_columns.append(np.concatenate(column_cells, axis=1))
    # Place the grid columns next to each other along the width axis.
    outputdata = np.concatenate(grid_columns, axis=2)
    skvideo.io.vwrite(filename, outputdata)
    print("Saved video to ", filename)
def get_image(goal, obs, recon_obs, imsize=84, pad_length=1, pad_color=255):
    """Stack goal, observation and reconstruction into one uint8 image.

    If ``goal`` is one-dimensional, all three inputs are reshaped to
    ``(-1, imsize, imsize)`` and transposed, giving arrays of shape
    ``(imsize, imsize, channels)``. The three arrays are concatenated along
    the first axis, multiplied by 255 and cast to ``uint8``.

    Args:
        goal: Desired-goal image (top panel).
        obs: Observation image (middle panel).
        recon_obs: Reconstructed observation image (bottom panel).
        imsize: Side length of one square panel in pixels.
        pad_length: Border width in pixels; a border is added only when > 0.
        pad_color: Border value, forwarded to ``add_border``.

    Returns:
        ``uint8`` array of shape ``(3 * imsize, imsize, channels)``, enlarged
        by ``2 * pad_length`` in height and width when a border is added.

    Note:
        Presumably the inputs hold floats in [0, 1]; values outside that
        range are not clipped here - confirm against callers.
    """
    if len(goal.shape) == 1:
        # Only ``goal`` is inspected; obs and recon_obs are assumed to share
        # its layout.
        goal = goal.reshape(-1, imsize, imsize).transpose()
        obs = obs.reshape(-1, imsize, imsize).transpose()
        recon_obs = recon_obs.reshape(-1, imsize, imsize).transpose()
    img = np.concatenate((goal, obs, recon_obs))
    img = np.uint8(255 * img)
    if pad_length > 0:
        # Forward imsize explicitly: add_border otherwise falls back to its
        # own default of 84 and mis-reshapes images of any other size.
        img = add_border(img, pad_length, pad_color, imsize=imsize)
    return img
def add_border(img, pad_length, pad_color, imsize=84):
    """Surround a three-panel strip with a constant-valued border.

    Args:
        img: Image data with ``3 * imsize * imsize * channels`` elements; it
            is reshaped to ``(3 * imsize, imsize, channels)``.
        pad_length: Border width in pixels on every side. ``0`` returns the
            reshaped image without a border.
        pad_color: Value the border is filled with.
        imsize: Side length of one square panel in pixels.

    Returns:
        Array of shape ``(3 * imsize + 2 * pad_length,
        imsize + 2 * pad_length, channels)`` with ``img`` in the centre.

    Raises:
        ValueError: If ``pad_length`` is negative.
    """
    if pad_length < 0:
        raise ValueError("pad_length must be non-negative")
    H = 3 * imsize
    W = imsize
    img = img.reshape((H, W, -1))
    canvas = np.ones(
        (H + 2 * pad_length, W + 2 * pad_length, img.shape[2]),
        dtype=np.uint8,
    ) * pad_color
    # Explicit end indices: a ``-pad_length`` stop collapses to ``0:0`` when
    # pad_length is 0, which made the assignment fail.
    canvas[pad_length:pad_length + H, pad_length:pad_length + W, :] = img
    return canvas