# bvk1ng's picture
# Stage-1 commit: Agent trained for 3500 episodes
# c121225
"""
Main script to run the Atari Breakout-v0 game.
The DQN algorithm was used to train the agent.
@author: bvk1ng (Adityam Ghosh)
Date: 12/28/2023
"""
from typing import List, Dict, Any, Callable, Tuple, Union
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import albumentations as A
import cv2
import os
import argparse
from model import CNNModel
from utils import play_atari_game, gym
from gymnasium.wrappers.record_video import RecordVideo
K = 4
IM_SIZE = 84
class ImageTransform:
    """Turn an RGB frame into a cropped, square, single-channel image.

    The pipeline is built once in the constructor and reused for every frame:
    a fixed crop to the window x in [0, 160), y in [34, 200), followed by a
    nearest-neighbour resize to ``IM_SIZE`` x ``IM_SIZE``.
    """

    def __init__(self):
        # Stage 1: keep only the fixed window of the frame.
        crop_stage = A.Crop(
            x_min=0, y_min=34, x_max=160, y_max=200, always_apply=True
        )
        # Stage 2: shrink the cropped window to the square network input size.
        resize_stage = A.Resize(
            height=IM_SIZE,
            width=IM_SIZE,
            interpolation=cv2.INTER_NEAREST,
            always_apply=True,
        )
        self.compose = A.Compose([crop_stage, resize_stage])

    def transform(self, img: np.ndarray) -> np.ndarray:
        """Convert ``img`` from RGB to grayscale, then crop and resize it.

        Args:
            img: the frame to preprocess; it is passed to
                ``cv2.cvtColor(..., cv2.COLOR_RGB2GRAY)``, so it is expected
                to be an RGB image.

        Returns:
            The ``"image"`` entry produced by the composed crop + resize
            pipeline.
        """
        grayscale = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        return self.compose(image=grayscale)["image"]
class DQN:
    """Prediction wrapper around ``CNNModel``.

    Builds the network, moves it to ``device`` and, when a checkpoint path is
    given, restores its weights from that file.

    Args:
        K: forwarded to ``CNNModel`` as ``K``; also stored on ``self.K``.
        cnn_params: forwarded to ``CNNModel`` as ``cnn_params``.
        fully_connected_params: forwarded to ``CNNModel`` as
            ``fully_connected_params``.
        device: torch device string the model and the inputs are moved to.
        load_path: optional path of a saved ``state_dict``; ``None`` skips
            loading and keeps the freshly initialised weights.
    """

    def __init__(
        self,
        K: int,
        cnn_params: List,
        fully_connected_params: List,
        device: str = "cuda",
        load_path: Union[str, None] = None,
    ):
        self.K = K
        # Set the device before ``load`` runs: ``load`` uses it to remap the
        # checkpoint tensors.
        self.device = device
        self.cnn_model = CNNModel(
            K=K,
            cnn_params=cnn_params,
            fully_connected_params=fully_connected_params,
        ).to(device=device)
        self.load(load_path)

    def predict(self, states: np.ndarray) -> torch.Tensor:
        """Run the network on a batch of states and return its output on CPU.

        Args:
            states: 4-D array whose last axis is moved to position 1 before
                the forward pass, i.e. (N, H, W, T) -> (N, T, H, W). Values
                are divided by 255, so they are presumably raw pixel
                intensities in [0, 255] - confirm against the caller.

        Returns:
            The model output as a CPU tensor that does not require grad.
        """
        channels_first = np.transpose(states, (0, 3, 1, 2))  # (N, T, H, W)
        # Inference only: skip building an autograd graph that would be
        # discarded right away.
        with torch.no_grad():
            batch = torch.from_numpy(channels_first).float().to(device=self.device)
            # Out-of-place division on purpose: for a float32 array on CPU the
            # tensor shares memory with ``states``, so ``/=`` would rescale the
            # caller's array.
            batch = batch / 255.0
            return self.cnn_model(batch).detach().cpu()

    def load(self, path: Union[str, None]) -> None:
        """Restore the model weights from ``path``; do nothing when it is None.

        Args:
            path: file previously written with ``torch.save(state_dict, ...)``,
                or ``None`` to leave the current weights untouched.
        """
        if path is None:
            return
        # ``map_location`` lets a checkpoint saved on one device (e.g. CUDA) be
        # restored onto whichever device this instance actually runs on.
        state_dict = torch.load(path, map_location=self.device)
        self.cnn_model.load_state_dict(state_dict)
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the play script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_folder",
        "-mF",
        type=str,
        required=False,
        default="./models",
        help="the folder to store the models.",
    )
    parser.add_argument(
        "--model_name",
        "-mf",
        type=str,
        required=False,
        default="atari_breakout_v0.pt",
        help="the name of the model to save.",
    )
    parser.add_argument(
        "--save_video",
        "-s",
        type=int,
        required=False,
        default=0,
        help="whether to save a video of the gameplay or not.",
    )
    parser.add_argument(
        "--video_folder",
        "-V",
        type=str,
        required=False,
        default="./videos",
        help="where to save the video.",
    )
    parser.add_argument(
        "--video_name",
        "-v",
        type=str,
        required=False,
        default="atari_breakout_v0",
        help="the name of the video file.",
    )
    return parser.parse_args()


def main() -> None:
    """Build the agent, create the Breakout environment and play one game.

    If ``<model_folder>/<model_name>`` exists, the weights are loaded from
    it; otherwise the agent plays with freshly initialised weights. With a
    truthy ``--save_video`` the game is rendered off-screen and recorded into
    ``--video_folder``; otherwise it is rendered in a window.
    """
    args = _parse_args()

    cnn_params = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
    fully_connected_params = [512]

    # Only hand a checkpoint to the model when the file is actually there.
    checkpoint = os.path.join(args.model_folder, args.model_name)
    load_path = checkpoint if os.path.exists(checkpoint) else None

    # Fall back to CPU so the script also runs on hosts without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE: ``DQN.__init__`` takes no ``lr`` argument (this script only
    # predicts); passing one raised a TypeError before the game could start.
    model = DQN(
        K=K,
        cnn_params=cnn_params,
        fully_connected_params=fully_connected_params,
        device=device,
        load_path=load_path,
    )
    img_transformer = ImageTransform()

    if args.save_video:
        env = gym.make("Breakout-v0", render_mode="rgb_array")
        env = RecordVideo(
            env=env, video_folder=args.video_folder, name_prefix=args.video_name
        )
        env.reset()
        env.start_video_recorder()
    else:
        env = gym.make("Breakout-v0", render_mode="human")

    try:
        play_atari_game(env=env, model=model, img_transform=img_transformer)
    finally:
        # Always release the emulator/window and let the recorder finalize
        # the video, even if the game loop raises.
        env.close()


if __name__ == "__main__":
    main()