jostyposty's picture
feat: add four models
3261e0d
import os
import tempfile
import imageio
from stable_baselines3.common.vec_env import VecVideoRecorder
import numpy as np
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
def generate_video(model, video_fp, video_length_in_episodes=5):
    """Record deterministic rollouts of ``model`` and write them to ``video_fp``.

    The model's own vectorized environment (``model.get_env()``) is wrapped in a
    ``VecVideoRecorder`` that records into a temporary directory. Rollouts stop
    after ``video_length_in_episodes`` episodes, or after
    ``video_length_in_episodes * max_episode_steps`` steps. ffmpeg then copies
    the clip to ``video_fp``. When the final step ended an episode, the last
    frame is dropped, because the env has already auto-reset and that frame
    shows the first observation of a new episode.

    NOTE(review): ``vec_env.close()`` presumably also closes the wrapped
    ``model.get_env()`` (SB3 ``VecEnvWrapper`` behaviour) — confirm before
    reusing the model's env afterwards.

    Args:
        model: A trained SB3 model. Its env presumably must render
            ``rgb_array`` frames (see ``load_ppo_model_for_video``).
        video_fp: Destination path of the output video.
        video_length_in_episodes: Number of episodes to record.

    Raises:
        FileNotFoundError: If the recorder produced no video file, or if the
            ``ffmpeg`` executable cannot be found.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
            The captured stderr is attached to the exception.
    """
    import glob
    import subprocess

    eval_env = model.get_env()
    max_video_length_in_steps = (
        video_length_in_episodes * eval_env.get_attr("spec")[0].max_episode_steps
    )
    with tempfile.TemporaryDirectory() as temp_dp:
        vec_env = VecVideoRecorder(
            eval_env,
            temp_dp,
            # Start a single recording on the very first step.
            record_video_trigger=lambda step: step == 0,
            video_length=max_video_length_in_steps,
        )
        frame_count = 0
        episode_count = 0
        last_step_done = False
        obs = vec_env.reset()
        for _ in range(max_video_length_in_steps):
            action, _ = model.predict(obs, deterministic=True)
            obs, _, dones, _ = vec_env.step(action)
            frame_count += 1
            # ``dones`` holds one flag per sub-env. any() avoids the
            # "truth value of an array is ambiguous" error for >1 env.
            last_step_done = bool(np.any(dones))
            if last_step_done:
                episode_count += 1
                if episode_count >= video_length_in_episodes:
                    break
        # Closing the recorder flushes the recording to disk.
        vec_env.close()

        # Find the recording by globbing the output directory instead of
        # reading a recorder attribute, whose name differs between versions.
        recordings = sorted(glob.glob(os.path.join(temp_dp, "*.mp4")))
        if not recordings:
            raise FileNotFoundError(f"No video was recorded in {temp_dp}")
        temp_fp = recordings[0]

        cmd = ["ffmpeg", "-y", "-i", temp_fp]
        if last_step_done:
            # Frame 0 is the reset frame, so the frame captured after step
            # ``frame_count`` has index ``frame_count``. The single quotes
            # are for ffmpeg's own filter parser: they keep the comma inside
            # eq() from being read as a filter separator.
            cmd += ["-vf", f"select='not(eq(n,{frame_count}))'"]
        cmd.append(video_fp)
        # Argument list, no shell: paths are passed through verbatim.
        subprocess.run(cmd, check=True, capture_output=True)
def generate_gif(
    model,
    file_path,
    video_length_in_episodes=5,
    *,
    deterministic=False,
    fps=25,
    frame_stride=2,
):
    """Roll out ``model`` in its own env and save the rendered frames as a GIF.

    One frame is rendered after the initial reset and one after every step.
    The frame after the step that completes the final episode is not
    captured, because the env has already auto-reset and that frame would
    show the start of a new episode. Only every ``frame_stride``-th frame
    (counting from the reset frame) is rendered and kept. Skipped frames are
    never rendered, which keeps memory use proportional to the output.

    Args:
        model: A trained SB3 model. Its env presumably must render
            ``rgb_array`` frames (see ``load_ppo_model_for_video``).
        file_path: Destination path of the GIF.
        video_length_in_episodes: Number of episodes to roll out (bounded by
            ``video_length_in_episodes * max_episode_steps`` steps).
        deterministic: Passed to ``model.predict``. The default ``False``
            matches the previous behaviour of calling ``model.predict(obs)``.
        fps: Frame rate passed to ``imageio.mimsave``.
        frame_stride: Keep every ``frame_stride``-th frame (must be >= 1).

    Raises:
        ValueError: If ``frame_stride`` is less than 1.
    """
    if frame_stride < 1:
        raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

    eval_env = model.get_env()
    max_video_length_in_steps = (
        video_length_in_episodes * eval_env.get_attr("spec")[0].max_episode_steps
    )

    def render_frame():
        # Copy the frame, in case the renderer reuses its output buffer.
        return np.array(eval_env.render(mode="rgb_array"))

    episode_count = 0
    obs = eval_env.reset()
    frames = [render_frame()]  # frame index 0 is always kept
    next_frame_index = 1
    for _ in range(max_video_length_in_steps):
        action, _ = model.predict(obs, deterministic=deterministic)
        obs, _, dones, _ = eval_env.step(action)
        # One flag per sub-env. any() is safe for any number of envs.
        if np.any(dones):
            episode_count += 1
            if episode_count >= video_length_in_episodes:
                break
        if next_frame_index % frame_stride == 0:
            frames.append(render_frame())
        next_frame_index += 1
    imageio.mimsave(file_path, frames, fps=fps)
def load_ppo_model_for_video(model_fp, env_id):
    """Load a saved PPO model bound to a single renderable environment.

    The environment is ``gym.make(env_id, render_mode="rgb_array")``, wrapped in
    ``Monitor`` and then in a one-env ``DummyVecEnv``. This lets frames be
    rendered from ``model.get_env()``.

    Args:
        model_fp: Path of the saved PPO model.
        env_id: Gymnasium environment id to construct.

    Returns:
        The loaded ``PPO`` model with the environment attached.
    """

    def make_monitored_env():
        return Monitor(gym.make(env_id, render_mode="rgb_array"))

    vectorized_env = DummyVecEnv([make_monitored_env])
    return PPO.load(model_fp, env=vectorized_env)