import torch
import numpy as np

import src.rlkit.torch.pytorch_util as ptu
from src.rlkit.data_management.normalizer import Normalizer, FixedNormalizer


class TorchNormalizer(Normalizer):
    """
    Update with np arrays, but de/normalize pytorch Tensors.
    """

    def normalize(self, v, clip_range=None):
        if not self.synchronized:
            self.synchronize()
        if clip_range is None:
            clip_range = self.default_clip_range
        mean = ptu.from_numpy(self.mean)
        std = ptu.from_numpy(self.std)
        if v.dim() == 2:
            # Unsqueeze along the batch dimension and rely on broadcasting.
            mean = mean.unsqueeze(0)
            std = std.unsqueeze(0)
        return torch.clamp((v - mean) / std, -clip_range, clip_range)

    def denormalize(self, v):
        if not self.synchronized:
            self.synchronize()
        mean = ptu.from_numpy(self.mean)
        std = ptu.from_numpy(self.std)
        if v.dim() == 2:
            mean = mean.unsqueeze(0)
            std = std.unsqueeze(0)
        return mean + v * std


class TorchFixedNormalizer(FixedNormalizer):
    """
    De/normalize pytorch Tensors using fixed (preset) mean and std.
    """

    def normalize(self, v, clip_range=None):
        if clip_range is None:
            clip_range = self.default_clip_range
        mean = ptu.from_numpy(self.mean)
        std = ptu.from_numpy(self.std)
        if v.dim() == 2:
            # Unsqueeze along the batch dimension and rely on broadcasting.
            mean = mean.unsqueeze(0)
            std = std.unsqueeze(0)
        return torch.clamp((v - mean) / std, -clip_range, clip_range)

    def normalize_scale(self, v):
        """
        Only normalize the scale. Do not subtract the mean.
        """
        std = ptu.from_numpy(self.std)
        if v.dim() == 2:
            std = std.unsqueeze(0)
        return v / std

    def denormalize(self, v):
        mean = ptu.from_numpy(self.mean)
        std = ptu.from_numpy(self.std)
        if v.dim() == 2:
            mean = mean.unsqueeze(0)
            std = std.unsqueeze(0)
        return mean + v * std

    def denormalize_scale(self, v):
        """
        Only denormalize the scale. Do not add the mean.
        """
        std = ptu.from_numpy(self.std)
        if v.dim() == 2:
            std = std.unsqueeze(0)
        return v * std
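

# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal example of the intended flow: running statistics are accumulated from
# numpy batches, then torch Tensors are normalized with broadcasting over the batch
# dimension. It assumes the rlkit-style Normalizer API (a `size` constructor argument
# and an `update(np_array)` method) and that `ptu.set_gpu_mode` configures the device;
# verify these against the local `src.rlkit` fork before relying on them.
if __name__ == "__main__":
    ptu.set_gpu_mode(False)  # assumed helper: select CPU tensors in pytorch_util

    obs_dim = 3
    normalizer = TorchNormalizer(size=obs_dim)

    # Update running statistics with numpy data...
    normalizer.update(np.random.randn(128, obs_dim).astype(np.float32))

    # ...then normalize a torch batch; mean/std broadcast over the batch dimension.
    batch = torch.randn(16, obs_dim)
    normalized = normalizer.normalize(batch)
    reconstructed = normalizer.denormalize(normalized)
    print(normalized.shape, reconstructed.shape)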