Spaces:
Sleeping
Sleeping
import abc | |
from collections import OrderedDict | |
from typing import Iterable | |
from torch import nn as nn | |
from src.rlkit.core.batch_rl_algorithm import BatchRLAlgorithm | |
from src.rlkit.core.batch_normalized_rl_algorithm import BatchNormalRLAlgorithm | |
from src.rlkit.core.online_rl_algorithm import OnlineRLAlgorithm | |
from src.rlkit.core.trainer import Trainer | |
from src.rlkit.torch.core import np_to_pytorch_batch | |
class TorchOnlineRLAlgorithm(OnlineRLAlgorithm):
    """Online RL algorithm whose trainer exposes PyTorch networks."""

    def to(self, device):
        """Call ``.to(device)`` on every network held by ``self.trainer``."""
        modules = self.trainer.networks
        for module in modules:
            module.to(device)

    def training_mode(self, mode):
        """Call ``.train(mode)`` on every network held by ``self.trainer``.

        For ``torch.nn.Module`` objects this toggles between training
        (``mode=True``) and evaluation (``mode=False``) behaviour.
        """
        modules = self.trainer.networks
        for module in modules:
            module.train(mode)
class TorchBatchRLAlgorithm(BatchRLAlgorithm):
    """Batch RL algorithm driving a trainer whose networks are torch modules."""

    def to(self, device):
        """Relocate each entry of ``self.trainer.networks`` onto ``device``."""
        trainer = self.trainer
        for network in trainer.networks:
            network.to(device)

    def training_mode(self, mode):
        """Forward ``mode`` to ``.train`` on each of ``self.trainer.networks``.

        Presumably used to switch between training and evaluation phases.
        Confirm against the ``BatchRLAlgorithm`` caller.
        """
        trainer = self.trainer
        for network in trainer.networks:
            network.train(mode)
class TorchBatchNormalRLAlgorithm(BatchNormalRLAlgorithm):
    """``BatchNormalRLAlgorithm`` variant whose trainer owns PyTorch networks."""

    def to(self, device):
        """Send every trainer network to ``device`` via its ``.to`` method."""
        for net_module in self.trainer.networks:
            net_module.to(device)

    def training_mode(self, mode):
        """Set the train/eval flag of every trainer network to ``mode``.

        Delegates to each network's ``.train(mode)``. The networks are
        assumed to be ``torch.nn.Module`` instances; confirm in the trainer.
        """
        for net_module in self.trainer.networks:
            net_module.train(mode)
class TorchTrainer(Trainer, metaclass=abc.ABCMeta):
    """Abstract trainer whose update step operates on PyTorch batches.

    ``train`` converts an incoming numpy batch with ``np_to_pytorch_batch``.
    It then hands the result to ``train_from_torch``, which subclasses must
    implement. Subclasses must also expose their modules through the
    ``networks`` property.
    """

    def __init__(self):
        # Number of times ``train`` has been called; reported by
        # ``get_diagnostics``.
        self._num_train_steps = 0

    def train(self, np_batch):
        """Run one training step on a numpy batch.

        Args:
            np_batch: batch accepted by ``np_to_pytorch_batch``. Presumably
                a dict of numpy arrays; confirm against the data source.
        """
        self._num_train_steps += 1
        batch = np_to_pytorch_batch(np_batch)
        self.train_from_torch(batch)

    def get_diagnostics(self):
        """Return an ordered mapping of training statistics."""
        return OrderedDict([
            ('num train calls', self._num_train_steps),
        ])

    @abc.abstractmethod
    def train_from_torch(self, batch):
        """Perform one update from a batch already converted to torch."""

    @property
    @abc.abstractmethod
    def networks(self) -> Iterable[nn.Module]:
        """Modules owned by this trainer.

        Declared as a property because callers iterate
        ``trainer.networks`` directly (without calling it) to move or
        mode-switch every module. A plain method here would make that
        iteration raise ``TypeError``.
        """