Spaces:
Sleeping
Sleeping
import torch | |
import src.rlkit.torch.pytorch_util as ptu | |
import numpy as np | |
from src.rlkit.data_management.normalizer import Normalizer, FixedNormalizer | |
class TorchNormalizer(Normalizer):
    """
    Update with np array, but de/normalize pytorch Tensors.

    ``self.mean`` / ``self.std`` are held by the ``Normalizer`` base class as
    numpy arrays; they are converted to tensors on every (de)normalize call.
    """

    @staticmethod
    def _stat_like(stat, v):
        """
        Convert the numpy statistic ``stat`` into a tensor usable with ``v``.

        The tensor is placed on ``v``'s device (so CPU and GPU inputs both
        work regardless of the ptu default device) and gets a leading batch
        axis when ``v`` is 2-D.
        """
        stat = ptu.from_numpy(stat).to(v.device)
        if v.dim() == 2:
            # Unsqueeze along the batch dimension so the statistic broadcasts
            # against a (batch, features) input.
            stat = stat.unsqueeze(0)
        return stat

    def normalize(self, v, clip_range=None):
        """
        Return ``clamp((v - mean) / std, -clip_range, clip_range)``.

        :param v: torch Tensor to normalize.
        :param clip_range: symmetric clamp bound; ``None`` means use
            ``self.default_clip_range``.
        """
        if not self.synchronized:
            # synchronize() is defined in the Normalizer base class;
            # presumably it recomputes mean/std from pending updates.
            self.synchronize()
        if clip_range is None:
            clip_range = self.default_clip_range
        mean = self._stat_like(self.mean, v)
        # NOTE(review): no epsilon is added here; relies on the base class
        # keeping std strictly positive — confirm.
        std = self._stat_like(self.std, v)
        return torch.clamp((v - mean) / std, -clip_range, clip_range)

    def denormalize(self, v):
        """
        Map ``v`` back from normalized units: ``mean + v * std``.

        Inverse of :meth:`normalize` for values that were not clipped.

        :param v: torch Tensor in normalized units.
        """
        if not self.synchronized:
            self.synchronize()
        mean = self._stat_like(self.mean, v)
        std = self._stat_like(self.std, v)
        return mean + v * std
class TorchFixedNormalizer(FixedNormalizer):
    """
    Tensor counterpart of ``FixedNormalizer``: de/normalize pytorch Tensors.

    ``self.mean`` / ``self.std`` are held by the base class as numpy arrays
    and converted to tensors on every call. Unlike ``TorchNormalizer`` this
    class never calls ``synchronize()``.
    """

    @staticmethod
    def _stat_like(stat, v):
        """
        Convert the numpy statistic ``stat`` into a tensor usable with ``v``.

        The tensor is placed on ``v``'s device (so CPU and GPU inputs both
        work regardless of the ptu default device) and gets a leading batch
        axis when ``v`` is 2-D.
        """
        stat = ptu.from_numpy(stat).to(v.device)
        if v.dim() == 2:
            # Unsqueeze along the batch dimension so the statistic broadcasts
            # against a (batch, features) input.
            stat = stat.unsqueeze(0)
        return stat

    def normalize(self, v, clip_range=None):
        """
        Return ``clamp((v - mean) / std, -clip_range, clip_range)``.

        :param v: torch Tensor to normalize.
        :param clip_range: symmetric clamp bound; ``None`` means use
            ``self.default_clip_range``.
        """
        if clip_range is None:
            clip_range = self.default_clip_range
        mean = self._stat_like(self.mean, v)
        # NOTE(review): no epsilon is added here; relies on the base class
        # keeping std strictly positive — confirm.
        std = self._stat_like(self.std, v)
        return torch.clamp((v - mean) / std, -clip_range, clip_range)

    def normalize_scale(self, v):
        """
        Only normalize the scale. Do not subtract the mean.

        :param v: torch Tensor; returns ``v / std`` (no clipping).
        """
        return v / self._stat_like(self.std, v)

    def denormalize(self, v):
        """
        Map ``v`` back from normalized units: ``mean + v * std``.

        Inverse of :meth:`normalize` for values that were not clipped.

        :param v: torch Tensor in normalized units.
        """
        mean = self._stat_like(self.mean, v)
        std = self._stat_like(self.std, v)
        return mean + v * std

    def denormalize_scale(self, v):
        """
        Only denormalize the scale. Do not add the mean.

        :param v: torch Tensor; returns ``v * std``.
        """
        return v * self._stat_like(self.std, v)