File size: 963 Bytes
e085e3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import torch
from pathlib import Path

from agent import DQNAgent, MetricLogger
from wrappers import make_starpilot
import os
from train import train, fill_memory


# Build the environment via the project's wrapper factory.
env = make_starpilot()

# Report whether a CUDA device is visible to PyTorch.
use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

# Checkpoint passed to DQNAgent; None presumably means "start from fresh
# weights" -- confirm against DQNAgent in agent.py.
checkpoint = None
# Example of resuming from a saved checkpoint (update the path before use):
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

# Directory handed to both the logger and the agent.
path = "checkpoints/procgen-starpilot-dqn"
save_dir = Path(path)

# Create the directory (and any missing parents) in one call; exist_ok=True
# replaces the racy exists()-then-makedirs() check and is a no-op when the
# directory is already there.
save_dir.mkdir(parents=True, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training Vanilla DQN Agent!")
agent = DQNAgent(
    # NOTE(review): assumes make_starpilot yields single-channel 64x64
    # observations -- confirm against wrappers.py.
    state_dim=(1, 64, 64),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=256,
    checkpoint=checkpoint,
    exploration_rate_decay=0.999995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=200,
    max_memory_size=50000,
    learning_rate=0.0005,
)

# Pre-populate the agent before training; the meaning of 300 (steps vs.
# episodes) is defined by fill_memory in train.py -- TODO confirm.
fill_memory(agent, env, 300)
train(agent, env, logger)