NCERL-Diverse-PCG / src /rlkit /torch /torch_rl_algorithm.py
baiyanlali-zhao's picture
init
eaf2e33
raw
history blame
1.75 kB
import abc
from collections import OrderedDict
from typing import Iterable
from torch import nn as nn
from src.rlkit.core.batch_rl_algorithm import BatchRLAlgorithm
from src.rlkit.core.batch_normalized_rl_algorithm import BatchNormalRLAlgorithm
from src.rlkit.core.online_rl_algorithm import OnlineRLAlgorithm
from src.rlkit.core.trainer import Trainer
from src.rlkit.torch.core import np_to_pytorch_batch
class TorchOnlineRLAlgorithm(OnlineRLAlgorithm):
    """Online RL algorithm with device and train-mode helpers for the trainer's networks.

    Both helpers read ``self.trainer.networks``; the ``trainer`` attribute is
    not assigned in this class (presumably set by the base class -- confirm).
    """

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``.

        Args:
            device: Forwarded unchanged to each network's ``to`` method.
        """
        trainer_networks = self.trainer.networks
        for module in trainer_networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``.

        Args:
            mode: Forwarded unchanged to each network's ``train`` method
                (presumably a bool selecting training vs. evaluation
                behaviour -- confirm against callers).
        """
        trainer_networks = self.trainer.networks
        for module in trainer_networks:
            module.train(mode)
class TorchBatchRLAlgorithm(BatchRLAlgorithm):
    """Batch RL algorithm with device and train-mode helpers for the trainer's networks.

    Both helpers read ``self.trainer.networks``; the ``trainer`` attribute is
    not assigned in this class (presumably set by the base class -- confirm).
    """

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``.

        Args:
            device: Forwarded unchanged to each network's ``to`` method.
        """
        trainer_networks = self.trainer.networks
        for module in trainer_networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``.

        Args:
            mode: Forwarded unchanged to each network's ``train`` method
                (presumably a bool selecting training vs. evaluation
                behaviour -- confirm against callers).
        """
        trainer_networks = self.trainer.networks
        for module in trainer_networks:
            module.train(mode)
class TorchBatchNormalRLAlgorithm(BatchNormalRLAlgorithm):
    """``BatchNormalRLAlgorithm`` with device and train-mode helpers for the trainer's networks.

    Both helpers read ``self.trainer.networks``; the ``trainer`` attribute is
    not assigned in this class (presumably set by the base class -- confirm).
    """

    def to(self, device):
        """Call ``.to(device)`` on every entry of ``self.trainer.networks``.

        Args:
            device: Forwarded unchanged to each network's ``to`` method.
        """
        trainer_networks = self.trainer.networks
        for module in trainer_networks:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every entry of ``self.trainer.networks``.

        Args:
            mode: Forwarded unchanged to each network's ``train`` method
                (presumably a bool selecting training vs. evaluation
                behaviour -- confirm against callers).
        """
        trainer_networks = self.trainer.networks
        for module in trainer_networks:
            module.train(mode)
class TorchTrainer(Trainer, metaclass=abc.ABCMeta):
    """Abstract trainer that converts incoming batches before training on them.

    Subclasses must implement ``train_from_torch`` and the ``networks``
    property. The class keeps a count of how many times ``train`` was called.
    """

    def __init__(self):
        # Number of calls made to ``train`` so far.
        self._num_train_steps = 0

    def train(self, np_batch):
        """Count the call, convert the batch, and delegate to ``train_from_torch``.

        Args:
            np_batch: Batch passed to ``np_to_pytorch_batch`` (presumably a
                collection of numpy arrays -- confirm against that helper).
        """
        # The counter is bumped before conversion, so a failing conversion
        # still registers as a train call.
        self._num_train_steps += 1
        self.train_from_torch(np_to_pytorch_batch(np_batch))

    def get_diagnostics(self):
        """Return an ``OrderedDict`` holding the number of ``train`` calls."""
        diagnostics = OrderedDict()
        diagnostics['num train calls'] = self._num_train_steps
        return diagnostics

    @abc.abstractmethod
    def train_from_torch(self, batch):
        """Train on ``batch``, the output of ``np_to_pytorch_batch``.

        Must be implemented by subclasses.
        """

    @property
    @abc.abstractmethod
    def networks(self) -> Iterable[nn.Module]:
        """Iterable of the ``nn.Module`` objects owned by this trainer.

        Must be implemented by subclasses.
        """