|
""" |
|
Main script to run the Atari Breakout-v0 game. |
|
The DQN algorithm was used to train the agent. |
|
|
|
@author: bvk1ng (Adityam Ghosh) |
|
Date: 12/28/2023 |
|
""" |
|
|
|
from typing import List, Dict, Any, Callable, Tuple, Union |
|
|
|
import numpy as np |
|
import gymnasium as gym |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import albumentations as A |
|
import cv2 |
|
import os |
|
import argparse |
|
|
|
|
|
from model import CNNModel |
|
from utils import play_atari_game, gym |
|
from gymnasium.wrappers.record_video import RecordVideo |
|
|
|
|
|
# Passed to the DQN wrapper as ``K``.
# NOTE(review): presumably the number of stacked frames per state - confirm against CNNModel.
K = 4

# Side length (in pixels) of the square image produced by ImageTransform.
IM_SIZE = 84
|
|
|
|
|
class ImageTransform:
    """Preprocess a raw game frame before it is fed to the network.

    A frame is converted to grayscale, cropped to the box
    ``x in [0, 160), y in [34, 200)`` and resized to ``IM_SIZE x IM_SIZE``
    using nearest-neighbour interpolation.
    """

    def __init__(self):
        # ``p=1.0`` (always run the transform) replaces the deprecated
        # ``always_apply=True`` flag; the resulting pipeline is the same.
        self.compose = A.Compose(
            [
                A.Crop(x_min=0, y_min=34, x_max=160, y_max=200, p=1.0),
                A.Resize(
                    height=IM_SIZE,
                    width=IM_SIZE,
                    interpolation=cv2.INTER_NEAREST,
                    p=1.0,
                ),
            ]
        )

    def transform(self, img: np.ndarray) -> np.ndarray:
        """Return the grayscale, cropped and resized version of ``img``.

        Args:
            img: an RGB image (it is converted with ``cv2.COLOR_RGB2GRAY``).
                NOTE(review): the crop box assumes the frame is at least
                160 px wide and 200 px tall - confirm against the environment.

        Returns:
            The transformed single-channel image.
        """
        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        return self.compose(image=gray_img)["image"]
|
|
|
|
|
class DQN:
    """Inference-side wrapper around ``CNNModel``.

    Builds the network on the requested device, optionally restores its
    weights from a checkpoint, and exposes ``predict`` to evaluate the
    network on a batch of states.
    """

    def __init__(
        self,
        K: int,
        cnn_params: List,
        fully_connected_params: List,
        device: str = "cuda",
        load_path: Union[str, None] = None,
    ):
        """Create the network and optionally load saved weights.

        Args:
            K: forwarded to ``CNNModel`` as ``K``.
            cnn_params: forwarded to ``CNNModel`` as ``cnn_params``.
            fully_connected_params: forwarded to ``CNNModel`` as
                ``fully_connected_params``.
            device: torch device the model and the inputs are moved to.
            load_path: path of a saved ``state_dict``; ``None`` skips loading.
        """
        self.K = K
        self.device = device
        self.cnn_model = CNNModel(
            K=K,
            cnn_params=cnn_params,
            fully_connected_params=fully_connected_params,
        ).to(device=device)

        self.load(load_path)

    def predict(self, states: np.ndarray) -> torch.Tensor:
        """Run the network on a batch of states.

        Args:
            states: 4-D array whose last axis is moved to position 1 before
                the forward pass (channels-last -> channels-first).
                NOTE(review): presumably (batch, height, width, K) with
                pixel values in [0, 255] - confirm against the caller.

        Returns:
            The network output as a CPU tensor detached from autograd.
        """
        states = np.transpose(states, (0, 3, 1, 2))
        batch = torch.from_numpy(states).float().to(device=self.device)
        batch /= 255.0

        # Inference only: skip building the autograd graph instead of
        # building it and throwing it away with ``detach``.
        with torch.no_grad():
            output = self.cnn_model(batch)

        return output.detach().cpu()

    def load(self, path: Union[str, None]):
        """Load a saved ``state_dict`` into the network; no-op when ``path`` is None."""
        if path is None:
            return

        # ``map_location`` remaps the stored tensors onto this instance's
        # device, so a checkpoint saved on a GPU also loads on a CPU-only host.
        # NOTE: ``torch.load`` unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(path, map_location=self.device)
        self.cnn_model.load_state_dict(state_dict)
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: parse the CLI options, build the agent (restoring its
    # weights when a checkpoint exists), create the Breakout environment and
    # let the agent play, optionally recording the gameplay to a video.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_folder",
        "-mF",
        type=str,
        required=False,
        default="./models",
        help="the folder to store the models.",
    )
    parser.add_argument(
        "--model_name",
        "-mf",
        type=str,
        required=False,
        default="atari_breakout_v0.pt",
        help="the name of the model to save.",
    )
    parser.add_argument(
        "--save_video",
        "-s",
        type=int,
        required=False,
        default=0,
        help="whether to save a video of the gameplay or not.",
    )
    parser.add_argument(
        "--video_folder",
        "-V",
        type=str,
        required=False,
        default="./videos",
        help="where to save the video.",
    )
    parser.add_argument(
        "--video_name",
        "-v",
        type=str,
        required=False,
        default="atari_breakout_v0",
        help="the name of the video file.",
    )

    args = parser.parse_args()

    model_folder = args.model_folder
    model_name = args.model_name
    save_video = args.save_video
    video_folder = args.video_folder
    video_name = args.video_name

    # NOTE(review): each tuple is presumably (out_channels, kernel_size, stride)
    # and the list holds hidden-layer widths - confirm against CNNModel.
    cnn_params = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
    fully_connected_params = [512]

    # Only restore weights when the checkpoint file is actually present.
    model_path = os.path.join(model_folder, model_name)
    load_path = model_path if os.path.exists(model_path) else None
    if load_path is None:
        print(f"No checkpoint found at {model_path}; the agent will use untrained weights.")

    # Fall back to the CPU instead of crashing on hosts without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ``DQN.__init__`` takes no ``lr`` argument; passing one raised a TypeError.
    model = DQN(
        K=K,
        cnn_params=cnn_params,
        fully_connected_params=fully_connected_params,
        device=device,
        load_path=load_path,
    )

    img_transformer = ImageTransform()

    if save_video:
        env = gym.make("Breakout-v0", render_mode="rgb_array")
        env = RecordVideo(env=env, video_folder=video_folder, name_prefix=video_name)

        # The recorder is started explicitly after the first reset.
        env.reset()
        env.start_video_recorder()
    else:
        env = gym.make("Breakout-v0", render_mode="human")

    try:
        play_atari_game(env=env, model=model, img_transform=img_transformer)
    finally:
        # Release the emulator/window and let the video wrapper finalize its
        # file even when the game loop raises.
        env.close()
|
|