nithin04 committed · verified
Commit 96a08c9 · 1 Parent(s): 9dffda3

Update README.md

Files changed (1):
  1. README.md +28 -3
README.md CHANGED
@@ -26,12 +26,37 @@ This is a trained model of a **A2C** agent playing **PandaReachDense-v3**
 using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
 
 ## Usage (with Stable-baselines3)
-TODO: Add your code
+
 
 
 ```python
-from stable_baselines3 import ...
-from huggingface_sb3 import load_from_hub
+import os
+import gymnasium as gym
+import panda_gym
+from huggingface_sb3 import load_from_hub, package_to_hub
+from stable_baselines3 import A2C
+from stable_baselines3.common.evaluation import evaluate_policy
+from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
+from stable_baselines3.common.env_util import make_vec_env
+env_id = "PandaReachDense-v3"
+env = gym.make(env_id)
+s_size = env.observation_space.shape
+a_size = env.action_space
+env = make_vec_env(env_id, n_envs=4)
+env = VecNormalize(venv=env, norm_obs=True, norm_reward=True, clip_obs=10)
+model = A2C(policy="MultiInputPolicy", env=env, verbose=1)
+model.learn(1_000_000)
+model.save("a2c-PandaReachDense-v3")
+env.save("vec_normalize.pkl")
+from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
+eval_env = DummyVecEnv([lambda: gym.make("PandaReachDense-v3")])
+eval_env = VecNormalize.load("vec_normalize.pkl", eval_env)
+eval_env.render_mode = "rgb_array"
+eval_env.training = False
+eval_env.norm_reward = False
+model = A2C.load("a2c-PandaReachDense-v3")
+mean_reward, std_reward = evaluate_policy(model, eval_env)
+print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
 
 ...
 ```
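The added snippet covers training and local evaluation, but imports `load_from_hub` without using it. For readers who only want to run the published checkpoint, here is a minimal sketch (not part of the commit) of the reverse path: pulling the saved agent and the `vec_normalize.pkl` statistics back from the Hub. The repo id `nithin04/a2c-PandaReachDense-v3` and the checkpoint filename are assumptions, not taken from the diff above.

```python
# Sketch only: load the trained A2C agent back from the Hub and evaluate it.
# The repo_id and filenames below are assumptions, not confirmed by the diff.
import gymnasium as gym
import panda_gym  # registers the PandaReachDense-v3 environment
from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

# Download the saved model and the VecNormalize statistics from the Hub.
checkpoint = load_from_hub(
    repo_id="nithin04/a2c-PandaReachDense-v3",  # assumed repo id
    filename="a2c-PandaReachDense-v3.zip",      # assumed filename
)
stats_path = load_from_hub(
    repo_id="nithin04/a2c-PandaReachDense-v3",
    filename="vec_normalize.pkl",
)

# Rebuild the evaluation environment with the saved normalization statistics.
eval_env = DummyVecEnv([lambda: gym.make("PandaReachDense-v3")])
eval_env = VecNormalize.load(stats_path, eval_env)
eval_env.training = False     # freeze the running statistics at eval time
eval_env.norm_reward = False  # report the raw, unnormalized reward

model = A2C.load(checkpoint)
mean_reward, std_reward = evaluate_policy(model, eval_env)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
```

Disabling `training` and `norm_reward` on the loaded `VecNormalize` wrapper keeps the saved observation statistics frozen and reports raw rewards, matching how the snippet in the diff evaluates the agent.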