|
import argparse |
|
|
|
|
|
from pyvirtualdisplay import Display |
|
# Start an off-screen (invisible) X display before gym and the RL libraries are
# imported further down. Presumably this is so environment rendering (e.g. the
# replay video recorded when packaging the agent for the Hub) works on a
# headless machine -- NOTE(review): confirm that rendering is the reason.
_DISPLAY_SIZE = (1400, 900)
virtual_display = Display(size=_DISPLAY_SIZE, visible=0)
virtual_display.start()
|
|
|
|
|
import gym |
|
from huggingface_sb3 import load_from_hub, package_to_hub, push_to_hub |
|
from huggingface_hub import notebook_login |
|
from stable_baselines3 import PPO |
|
from stable_baselines3.common.evaluation import evaluate_policy |
|
from stable_baselines3.common.env_util import make_vec_env |
|
from stable_baselines3.common.vec_env import DummyVecEnv |
|
|
|
def str2bool(v):
    """Convert a command-line flag value into a ``bool``.

    Intended as an argparse ``type=`` converter. A value that is already a
    ``bool`` is returned unchanged; otherwise the value is lower-cased and
    matched against the accepted spellings.

    Args:
        v: A ``bool`` or a string such as ``"yes"``, ``"False"``, ``"1"``.

    Returns:
        ``True`` for yes/true/t/y/1, ``False`` for no/false/f/n/0
        (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v

    truthy_words = ('yes', 'true', 't', 'y', '1')
    falsy_words = ('no', 'false', 'f', 'n', '0')
    normalized = v.lower()

    if normalized in truthy_words:
        return True
    if normalized in falsy_words:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
|
|
# Command-line configuration. Arguments are parsed at import time so that the
# module-level `args` is available to the `__main__` block below.
parser = argparse.ArgumentParser()

# Local name for the saved model and the name used when packaging for the Hub.
parser.add_argument('--model_name', dest='model_name',
                    default="ppo-LunarLander-v2", type=str, help='model name')

parser.add_argument('--total_timesteps', dest='total_timesteps',
                    default=1000000, type=int, help='total timesteps')

# Number of parallel copies of the training environment.
parser.add_argument('--n_envs', dest='n_envs',
                    default=16, type=int, help='n_envs')

# Destination Hub repository for the trained agent.
parser.add_argument('--repo_id', dest='repo_id',
                    default="thien1892/LunarLander-v2-ppo", type=str, help='repo_id')

parser.add_argument('--commit_message', dest='commit_message',
                    default="Upload PPO LunarLander-v2 trained agent", type=str,
                    help='commit_message')

# When true, training resumes from the Hub checkpoint named by --id_retrain /
# --filename_retrain instead of building a fresh PPO model. The default is a
# real bool, so str2bool is only applied to values given on the command line.
# The help text previously said 'commit_message' (copy-paste error).
parser.add_argument('--re_train', dest='re_train',
                    default=True, type=str2bool,
                    help='resume training from the Hub checkpoint given by '
                         '--id_retrain/--filename_retrain instead of starting '
                         'a fresh PPO model')

# Hub repository and file of the checkpoint used when --re_train is true.
parser.add_argument('--id_retrain', dest='id_retrain',
                    default="thien1892/LunarLander-v2-ppo-5m", type=str, help='id_retrain')

parser.add_argument('--filename_retrain', dest='filename_retrain',
                    default="ppo-LunarLander-v2-5m.zip", type=str, help='filename_retrain')

# Used for both fresh models and for overriding the loaded checkpoint's rate.
parser.add_argument('--learning_rate', dest='learning_rate',
                    default=1e-4, type=float, help='learning_rate')

args = parser.parse_args()
|
|
|
if __name__ == '__main__':
    # Single source of truth for the environment id (previously the literal
    # was repeated in three places).
    env_id = "LunarLander-v2"
    model_architecture = "PPO"

    # Vectorised training environment with args.n_envs copies.
    env = make_vec_env(env_id, n_envs=args.n_envs)
    eval_env = None
    try:
        if not args.re_train:
            # Fresh PPO agent with fixed hyper-parameters; only the learning
            # rate is configurable from the command line.
            model = PPO(
                policy='MlpPolicy',
                env=env,
                n_steps=1024,
                batch_size=64,
                n_epochs=4,
                gamma=0.999,
                gae_lambda=0.98,
                ent_coef=0.01,
                learning_rate=args.learning_rate,
                verbose=1)
        else:
            # Continue from a checkpoint downloaded from the Hub, attached to
            # the new training env, with the learning rate overridden.
            checkpoint = load_from_hub(args.id_retrain, args.filename_retrain)
            model = PPO.load(checkpoint, print_system_info=True, env=env,
                             learning_rate=args.learning_rate)

        # reset_num_timesteps is a learn() argument. Passing it to PPO.load()
        # (as before) only stored an unused attribute on the model.
        model.learn(total_timesteps=args.total_timesteps, reset_num_timesteps=True)

        model.save(args.model_name)

        # Quick local evaluation on a single, non-vectorised environment.
        eval_env = gym.make(env_id)
        mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
        print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

        # package_to_hub wants a VecEnv for its own evaluation/replay. It wraps
        # and closes this env itself, so it is deliberately not closed below.
        hub_eval_env = DummyVecEnv([lambda: gym.make(env_id)])

        package_to_hub(model=model,
                       model_name=args.model_name,
                       model_architecture=model_architecture,
                       env_id=env_id,
                       eval_env=hub_eval_env,
                       repo_id=args.repo_id,
                       commit_message=args.commit_message)
    finally:
        # Release environments this script owns and the virtual display
        # started at import time, even if training or upload fails.
        if eval_env is not None:
            eval_env.close()
        env.close()
        virtual_display.stop()
|
|