baiyanlali-zhao's picture
init
eaf2e33
raw
history blame
3.73 kB
import os
import os.path as osp
import time
import numpy as np
import scipy.misc
import skvideo.io
from rlkit.envs.vae_wrapper import VAEWrappedEnv
def dump_video(
env,
policy,
filename,
rollout_function,
rows=3,
columns=6,
pad_length=0,
pad_color=255,
do_timer=True,
horizon=100,
dirname_to_save_images=None,
subdirname="rollouts",
imsize=84,
num_channels=3,
):
frames = []
H = 3 * imsize
W = imsize
N = rows * columns
for i in range(N):
start = time.time()
path = rollout_function(
env,
policy,
max_path_length=horizon,
render=False,
)
is_vae_env = isinstance(env, VAEWrappedEnv)
l = []
for d in path['full_observations']:
if is_vae_env:
recon = np.clip(env._reconstruct_img(d['image_observation']), 0,
1)
else:
recon = d['image_observation']
l.append(
get_image(
d['image_desired_goal'],
d['image_observation'],
recon,
pad_length=pad_length,
pad_color=pad_color,
imsize=imsize,
)
)
frames += l
if dirname_to_save_images:
rollout_dir = osp.join(dirname_to_save_images, subdirname, str(i))
os.makedirs(rollout_dir, exist_ok=True)
rollout_frames = frames[-101:]
goal_img = np.flip(rollout_frames[0][:imsize, :imsize, :], 0)
scipy.misc.imsave(rollout_dir + "/goal.png", goal_img)
goal_img = np.flip(rollout_frames[1][:imsize, :imsize, :], 0)
scipy.misc.imsave(rollout_dir + "/z_goal.png", goal_img)
for j in range(0, 101, 1):
img = np.flip(rollout_frames[j][imsize:, :imsize, :], 0)
scipy.misc.imsave(rollout_dir + "/" + str(j) + ".png", img)
if do_timer:
print(i, time.time() - start)
frames = np.array(frames, dtype=np.uint8)
path_length = frames.size // (
N * (H + 2 * pad_length) * (W + 2 * pad_length) * num_channels
)
frames = np.array(frames, dtype=np.uint8).reshape(
(N, path_length, H + 2 * pad_length, W + 2 * pad_length, num_channels)
)
f1 = []
for k1 in range(columns):
f2 = []
for k2 in range(rows):
k = k1 * rows + k2
f2.append(frames[k:k + 1, :, :, :, :].reshape(
(path_length, H + 2 * pad_length, W + 2 * pad_length,
num_channels)
))
f1.append(np.concatenate(f2, axis=1))
outputdata = np.concatenate(f1, axis=2)
skvideo.io.vwrite(filename, outputdata)
print("Saved video to ", filename)
def get_image(goal, obs, recon_obs, imsize=84, pad_length=1, pad_color=255):
if len(goal.shape) == 1:
goal = goal.reshape(-1, imsize, imsize).transpose()
obs = obs.reshape(-1, imsize, imsize).transpose()
recon_obs = recon_obs.reshape(-1, imsize, imsize).transpose()
img = np.concatenate((goal, obs, recon_obs))
img = np.uint8(255 * img)
if pad_length > 0:
img = add_border(img, pad_length, pad_color)
return img
def add_border(img, pad_length, pad_color, imsize=84):
H = 3 * imsize
W = imsize
img = img.reshape((3 * imsize, imsize, -1))
img2 = np.ones((H + 2 * pad_length, W + 2 * pad_length, img.shape[2]),
dtype=np.uint8) * pad_color
img2[pad_length:-pad_length, pad_length:-pad_length, :] = img
return img2