---
tags:
- reinforcement-learning
- stable-baselines3
---
# PPO Agent playing BreakoutNoFrameskip-v4
This is a trained model of a **PPO agent playing BreakoutNoFrameskip-v4 using the [stable-baselines3 library](https://stable-baselines3.readthedocs.io/en/master/index.html)**.

<video src="https://huggingface.co/ThomasSimonini/ppo-BreakoutNoFrameskip-v4/resolve/main/output.mp4" controls autoplay loop></video>

## Evaluation Results
Mean_reward: ``

# Usage (with Stable-baselines3)
- You need to use `gym==0.19`, since it **includes the Atari ROMs**.
- The action space is 6, since we use only the **actions that are possible in this game** (a quick check follows this list).
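
To sanity-check the environment before loading the agent, you can build it and inspect the action space. A minimal sketch, assuming `gym==0.19` and `stable-baselines3` are installed:

```python
from stable_baselines3.common.env_util import make_atari_env

# make_atari_env applies the standard Atari preprocessing wrappers
env = make_atari_env('BreakoutNoFrameskip-v4', n_envs=1)
print("Action space:", env.action_space)  # the note above reports 6 possible actions
```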

Watch your agent interact:
```python
# Import the libraries
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

from huggingface_sb3 import load_from_hub

# Load the model checkpoint from the Hub
checkpoint = load_from_hub("ThomasSimonini/ppo-BreakoutNoFrameskip-v4", "ppo-BreakoutNoFrameskip-v4.zip")

# Colab runs Python 3.7 while this agent was trained with Python 3.8,
# so we override these objects to avoid pickle errors:
custom_objects = {
    "learning_rate": 0.0,
    "lr_schedule": lambda _: 0.0,
    "clip_range": lambda _: 0.0,
}

model = PPO.load(checkpoint, custom_objects=custom_objects)

# Recreate the training setup: Atari wrappers plus 4 stacked frames
env = make_atari_env('BreakoutNoFrameskip-v4', n_envs=1)
env = VecFrameStack(env, n_stack=4)

obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
```
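
The `Mean_reward` field above is left as a placeholder; if you want to measure it yourself, stable-baselines3 provides an `evaluate_policy` helper. A minimal sketch, reusing the `model` and `env` built in the snippet above (the episode count is an illustrative choice, not this card's evaluation protocol):

```python
from stable_baselines3.common.evaluation import evaluate_policy

# Average the episodic return over a few episodes
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
```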

## Training Code
```python
import wandb

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecVideoRecorder
from stable_baselines3.common.callbacks import CheckpointCallback

from wandb.integration.sb3 import WandbCallback

from huggingface_sb3 import push_to_hub

config = {
    "env_name": "BreakoutNoFrameskip-v4",
    "num_envs": 8,
    "total_timesteps": int(10e6),
    "seed": 661550378,
}

run = wandb.init(
    project="HFxSB3",
    config=config,
    sync_tensorboard=True,  # Auto-upload sb3's tensorboard metrics
    monitor_gym=True,       # Auto-upload the videos of agents playing the game
    save_code=True,         # Save the code to W&B
)

# There already exists an environment generator
# that will make and wrap Atari environments correctly.
# Here we also train with multiple workers (n_envs=8 => 8 environments)
env = make_atari_env(config["env_name"], n_envs=config["num_envs"], seed=config["seed"])

print("ENV ACTION SPACE: ", env.action_space.n)

# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)
# Video recorder: capture a 2000-step clip every 100,000 steps
env = VecVideoRecorder(env, "videos", record_video_trigger=lambda x: x % 100000 == 0, video_length=2000)

model = PPO(
    policy="CnnPolicy",
    env=env,
    batch_size=256,
    clip_range=0.1,
    ent_coef=0.01,
    gae_lambda=0.9,
    gamma=0.99,
    learning_rate=2.5e-4,
    max_grad_norm=0.5,
    n_epochs=4,
    n_steps=128,
    vf_coef=0.5,
    tensorboard_log="runs",
    verbose=1,
)

model.learn(
    total_timesteps=config["total_timesteps"],
    callback=[
        WandbCallback(
            gradient_save_freq=1000,
            model_save_path=f"models/{run.id}",
        ),
        CheckpointCallback(
            save_freq=10000,
            save_path="./breakout",
            name_prefix=config["env_name"],
        ),
    ],
)

model.save("ppo-BreakoutNoFrameskip-v4.zip")
push_to_hub(
    repo_id="ThomasSimonini/ppo-BreakoutNoFrameskip-v4",
    filename="ppo-BreakoutNoFrameskip-v4.zip",
    commit_message="Added Breakout trained agent",
)
```
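
If a long run is interrupted, the checkpoints written by `CheckpointCallback` above can be used to resume training. A minimal sketch, assuming SB3's usual checkpoint naming of `<name_prefix>_<num_timesteps>_steps.zip` (the step count in the path below is hypothetical; pick the latest file in `./breakout`):

```python
# Hypothetical checkpoint path written by CheckpointCallback during the run above
model = PPO.load("./breakout/BreakoutNoFrameskip-v4_100000_steps", env=env)

# reset_num_timesteps=False keeps the timestep counter (and logging) continuous
model.learn(total_timesteps=config["total_timesteps"], reset_num_timesteps=False)
```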