|
import os
import subprocess
import tempfile
import warnings

import gymnasium as gym
import imageio
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import VecVideoRecorder
|
|
|
|
|
def generate_video(model, video_fp, video_length_in_episodes=5):
    """Record the model acting in its own environment and write a video file.

    The rollout is captured by ``VecVideoRecorder`` into a temporary
    directory; ``ffmpeg`` then re-encodes that recording to ``video_fp``
    while dropping the frame whose index equals the number of steps taken.

    Args:
        model: Agent providing ``get_env()`` and ``predict()``; its
            environment is wrapped by the recorder for the rollout.
        video_fp: Destination path of the final video.
        video_length_in_episodes: Number of episodes to record. The rollout
            is also capped at this many times the ``max_episode_steps`` of
            the first sub-environment's spec.

    Warns:
        RuntimeWarning: If ``ffmpeg`` cannot be started or exits with a
            non-zero status (the video is then not written).
    """
    eval_env = model.get_env()

    max_video_length_in_steps = (
        video_length_in_episodes * eval_env.get_attr("spec")[0].max_episode_steps
    )

    with tempfile.TemporaryDirectory() as temp_dp:
        vec_env = VecVideoRecorder(
            eval_env,
            temp_dp,
            record_video_trigger=lambda x: x == 0,
            video_length=max_video_length_in_steps,
        )

        frame_count = 0
        episode_count = 0
        try:
            obs = vec_env.reset()
            for _ in range(max_video_length_in_steps):
                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, _ = vec_env.step(action)
                frame_count += 1
                # `dones` holds one flag per sub-environment; follow the first
                # one, matching the spec lookup above. Testing the whole array
                # for truth fails as soon as there is more than one env.
                if dones[0]:
                    episode_count += 1
                    if episode_count >= video_length_in_episodes:
                        break
        finally:
            # Closing finalizes the recording; do it even if the rollout fails.
            vec_env.close()

        temp_fp = vec_env.video_recorder.path

        # Must run while the temporary directory still exists.
        _drop_frame_with_ffmpeg(temp_fp, video_fp, frame_count)


def _drop_frame_with_ffmpeg(src_fp, dst_fp, frame_index):
    """Re-encode ``src_fp`` to ``dst_fp`` with ffmpeg, omitting one frame."""
    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters intact and cannot be used to inject commands.
    command = [
        "ffmpeg",
        "-y",
        "-i",
        str(src_fp),
        "-vf",
        f"select='not(eq(n,{frame_index}))'",
        str(dst_fp),
    ]
    try:
        completed = subprocess.run(
            command,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except OSError as err:
        # Typically: the ffmpeg executable is not installed / not on PATH.
        warnings.warn(
            f"could not run ffmpeg ({err}); {dst_fp} was not written",
            RuntimeWarning,
            stacklevel=3,
        )
        return

    if completed.returncode != 0:
        warnings.warn(
            f"ffmpeg exited with status {completed.returncode}; "
            f"{dst_fp} may be missing or incomplete",
            RuntimeWarning,
            stacklevel=3,
        )
|
|
|
|
|
|
|
def generate_gif(model, file_path, video_length_in_episodes=5):
    """Roll the model out in its own environment and save the frames as a GIF.

    One frame is rendered after the initial reset and after every step that
    does not end the rollout. Every second collected frame is written to
    ``file_path`` with ``fps=25``.

    Args:
        model: Agent providing ``get_env()`` and ``predict()``.
        file_path: Destination path handed to ``imageio.mimsave``.
        video_length_in_episodes: Number of episodes to capture. The rollout
            is also capped at this many times the ``max_episode_steps`` of
            the first sub-environment's spec.
    """
    eval_env = model.get_env()

    max_video_length_in_steps = (
        video_length_in_episodes * eval_env.get_attr("spec")[0].max_episode_steps
    )

    def render_image():
        return eval_env.render(mode="rgb_array")

    episode_count = 0
    obs = eval_env.reset()
    images = [render_image()]
    for _ in range(max_video_length_in_steps):
        action, _ = model.predict(obs)
        obs, _, dones, _ = eval_env.step(action)
        # `dones` holds one flag per sub-environment; follow the first one,
        # matching the spec lookup above. Testing the whole array for truth
        # fails as soon as there is more than one env.
        if dones[0]:
            episode_count += 1
            if episode_count >= video_length_in_episodes:
                break
        images.append(render_image())

    # Keep frames 0, 2, 4, ... to halve the size of the resulting GIF.
    imageio.mimsave(file_path, [np.array(img) for img in images[::2]], fps=25)
|
|
|
|
|
def load_ppo_model_for_video(model_fp, env_id):
    """Load a saved PPO model bound to a renderable single-env vector.

    Args:
        model_fp: Path of the saved model passed to ``PPO.load``.
        env_id: Environment id passed to ``gym.make``; the environment is
            created with ``render_mode="rgb_array"`` and wrapped in
            ``Monitor``.

    Returns:
        The loaded PPO model, attached to a ``DummyVecEnv`` holding that one
        environment.
    """

    def _make_env():
        base_env = gym.make(env_id, render_mode="rgb_array")
        return Monitor(base_env)

    vec_env = DummyVecEnv([_make_env])
    return PPO.load(model_fp, env=vec_env)
|
|