# ml-reinforcement-learning/src/procgen/run-starpilot-dqn.py
# Uploaded by 00BER (commit e085e3b, "Upload 36 files"); 963 bytes.
import os
import torch
from pathlib import Path
from agent import DQNAgent, MetricLogger
from wrappers import make_starpilot
import os
from train import train, fill_memory
# Train a vanilla DQN agent on the Procgen StarPilot environment.
# The training loop and replay pre-fill helpers come from train.py; the
# output directory below is passed to both MetricLogger and DQNAgent
# (presumably for metrics and model checkpoints — confirm in agent.py).

env = make_starpilot()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

# Leave as None to start from scratch; set to a saved checkpoint Path to resume.
# NOTE(review): the example path below looks copied from an Airstriker script —
# confirm the correct StarPilot checkpoint file before enabling it.
checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

save_dir = Path("checkpoints/procgen-starpilot-dqn")
# Create the output directory and any missing parents. exist_ok=True makes
# this a no-op when the directory already exists, which avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
save_dir.mkdir(parents=True, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training Vanilla DQN Agent!")
agent = DQNAgent(
    # Presumably a single-channel 64x64 observation as produced by
    # make_starpilot() — TODO confirm against wrappers.py.
    state_dim=(1, 64, 64),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=256,
    checkpoint=checkpoint,
    exploration_rate_decay=0.999995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=200,
    max_memory_size=50000,
    learning_rate=0.0005,
)

# Pre-populate the agent's memory before training starts. Whether 300 counts
# steps or episodes is defined by train.fill_memory — verify there.
fill_memory(agent, env, 300)
train(agent, env, logger)