import glob
import os
import pickle
import tarfile
import urllib.request

import torch
import numpy as np
import gymnasium as gym
from huggingface_hub.utils import EntryNotFoundError
from huggingface_sb3 import load_from_hub
from moviepy.video.compositing.concatenate import concatenate_videoclips
from moviepy.video.io.VideoFileClip import VideoFileClip
from rl_zoo3 import ALGOS
from gymnasium.wrappers import RecordVideo
from stable_baselines3.common.running_mean_std import RunningMeanStd


def install_mujoco():
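    """Download and extract MuJoCo 2.1.0 into ./mujoco210 (if not already present) and
    export the environment variables that mujoco-py expects (MUJOCO_PY_MUJOCO_PATH and
    the library path to the MuJoCo bin directory)."""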
    mujoco_url = "https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz"
    mujoco_file = "mujoco210-linux-x86_64.tar.gz"
    mujoco_dir = "mujoco210"

    # Download and extract MuJoCo only if the directory is not already present
    if not os.path.exists(mujoco_dir):
        print("Downloading Mujoco...")
        urllib.request.urlretrieve(mujoco_url, mujoco_file)

        # Extract Mujoco
        print("Extracting Mujoco...")
        with tarfile.open(mujoco_file, "r:gz") as tar:
            tar.extractall()

        # Clean up the downloaded tar file
        os.remove(mujoco_file)

        print("Mujoco installed successfully!")
    else:
        print("Mujoco already installed.")

    # Set environment variable MUJOCO_PY_MUJOCO_PATH
    os.environ["MUJOCO_PY_MUJOCO_PATH"] = os.path.abspath(mujoco_dir)

    ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
    mujoco_bin_path = os.path.join(os.path.abspath(mujoco_dir), "bin")
    if mujoco_bin_path not in ld_library_path:
        os.environ["LD_LIBRARY_PATH"] = ld_library_path + ":" + mujoco_bin_path



class NormalizeObservation(gym.Wrapper):
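    """Observation-normalizing wrapper that applies pre-computed VecNormalize statistics
    (running mean and variance) and clips the result, so the expert receives observations
    in the same scale it saw during training."""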
    def __init__(self, env: gym.Env, clip_obs: float, obs_rms: RunningMeanStd, epsilon: float):
        gym.Wrapper.__init__(self, env)
        self.clip_obs = clip_obs
        self.obs_rms = obs_rms
        self.epsilon = epsilon

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        observation = self.normalize(np.array([observation]))[0]
        return observation, reward, terminated, truncated, info

    def reset(self, **kwargs):
        observation, info = self.env.reset(**kwargs)
        return self.normalize(np.array([observation]))[0], info

    def normalize(self, obs):
        return np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs)


class CreateDataset(gym.Wrapper):
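    """Wrapper that records every (observation, action) pair passing through the
    environment so completed rollouts can be exported as a supervised dataset."""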
    def __init__(self, env: gym.Env):
        gym.Wrapper.__init__(self, env)
        self.observations = []
        self.actions = []
        self.last_observation = None

    def step(self, action):
        self.observations.append(self.last_observation)
        self.actions.append(action)
        observation, reward, terminated, truncated, info = self.env.step(action)
        self.last_observation = observation
        return observation, reward, terminated, truncated, info

    def reset(self, **kwargs):
        observation, info = self.env.reset(**kwargs)
        self.last_observation = observation
        return observation, info

    def get_dataset(self):
        if isinstance(self.env.action_space, gym.spaces.Box) and self.env.action_space.shape != (1,):
            actions = np.vstack(self.actions)
        else:
            actions = np.hstack(self.actions)
        return np.vstack(self.observations), actions


def rollouts(env, policy, num_episodes=1):
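    """Run `num_episodes` episodes in `env` using `policy` (a function mapping an
    observation to an action), then close the environment."""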
    for episode in range(num_episodes):
        done = False
        observation, _ = env.reset()
        while not done:
            action = policy(observation)
            observation, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
    env.close()


def generate_dataset_from_expert(algo, env_name, num_train_episodes=5, num_test_episodes=2, force=False):
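    """Download a pretrained expert from the Hugging Face Hub, roll it out to build
    train/test sets of (observation, action) pairs, record evaluation videos, and return
    the paths to the saved dataset and the concatenated video.

    Results are cached on disk; pass force=True to regenerate them.
    """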
    if env_name.startswith("Swimmer") or env_name.startswith("Hopper"):
        install_mujoco()
    dataset_path = os.path.join("datasets", f"{algo}-{env_name}.pt")
    video_path = os.path.join("videos", f"{algo}-{env_name}.mp4")
    if os.path.exists(dataset_path) and os.path.exists(video_path) and not force:
        return dataset_path, video_path
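    # Pretrained experts are published on the Hugging Face Hub under the "sb3" organization.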
    repo_id = f"sb3/{algo}-{env_name}"
    policy_file = f"{algo}-{env_name}.zip"

    expert_path = load_from_hub(repo_id, policy_file)
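    # Some experts ship VecNormalize statistics; if observation normalization was used
    # during training, replicate it here so the expert sees inputs in the same scale.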
    try:
        vec_normalize_path = load_from_hub(repo_id, "vec_normalize.pkl")
        with open(vec_normalize_path, "rb") as f:
            vec_normalize = pickle.load(f)
            if vec_normalize.norm_obs:
                vec_normalize_params = {"clip_obs": vec_normalize.clip_obs, "obs_rms": vec_normalize.obs_rms, "epsilon": vec_normalize.epsilon}
            else:
                vec_normalize_params = None
    except EntryNotFoundError:
        vec_normalize_params = None

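    # Load the expert policy and build the train/test environments; both record
    # (observation, action) pairs, and the test environment also records video.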
    expert = ALGOS[algo].load(expert_path)
    train_env = gym.make(env_name)
    train_env = CreateDataset(train_env)
    if vec_normalize_params is not None:
        train_env = NormalizeObservation(train_env, **vec_normalize_params)
    test_env = gym.make(env_name, render_mode="rgb_array")
    test_env = CreateDataset(test_env)
    if vec_normalize_params is not None:
        test_env = NormalizeObservation(test_env, **vec_normalize_params)
    test_env = RecordVideo(test_env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"{algo}-{env_name}")

    def policy(obs):
        return expert.predict(obs, deterministic=True)[0]

    os.makedirs("videos", exist_ok=True)
    rollouts(train_env, policy, num_train_episodes)
    rollouts(test_env, policy, num_test_episodes)

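    # Export the recorded rollouts as torch tensors and save them as a dataset dictionary.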
    train_observations, train_actions = train_env.get_dataset()
    test_observations, test_actions = test_env.get_dataset()

    dataset = {
        "train_input": torch.from_numpy(train_observations),
        "test_input": torch.from_numpy(test_observations),
        "train_label": torch.from_numpy(train_actions),
        "test_label": torch.from_numpy(test_actions)
    }

    os.makedirs("datasets", exist_ok=True)
    torch.save(dataset, dataset_path)

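    # Concatenate the per-episode recordings produced by RecordVideo into a single video file.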
    video_files = sorted(glob.glob(os.path.join("videos", f"{algo}-{env_name}-episode*.mp4")))
    clips = [VideoFileClip(file) for file in video_files]
    final_clip = concatenate_videoclips(clips)
    final_clip.write_videofile(video_path, codec="libx264", fps=24)

    return dataset_path, video_path


if __name__ == "__main__":
    generate_dataset_from_expert("ppo", "CartPole-v1", force=True)
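
    # Added example (assumption, not part of the original script): reload the dataset
    # that was just written and print the tensor shapes as a quick sanity check.
    saved = torch.load(os.path.join("datasets", "ppo-CartPole-v1.pt"))
    print({name: tuple(tensor.shape) for name, tensor in saved.items()})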