baiyanlali-zhao's picture
init
eaf2e33
raw
history blame
2.28 kB
import torch
import src.rlkit.torch.pytorch_util as ptu
import numpy as np
from src.rlkit.data_management.normalizer import Normalizer, FixedNormalizer
class TorchNormalizer(Normalizer):
    """
    Normalizer whose statistics are updated with np arrays, while
    normalization and denormalization operate on pytorch Tensors.
    """

    def normalize(self, v, clip_range=None):
        """
        Standardize ``v`` with the stored mean/std and clamp the result.

        :param v: tensor to normalize.
        :param clip_range: symmetric clamp bound; when None,
            ``self.default_clip_range`` is used instead.
        :return: ``(v - mean) / std`` clamped to
            ``[-clip_range, clip_range]``.
        """
        # Bring the statistics up to date before they are read.
        if not self.synchronized:
            self.synchronize()
        bound = self.default_clip_range if clip_range is None else clip_range
        mu = ptu.from_numpy(self.mean)
        sigma = ptu.from_numpy(self.std)
        if v.dim() == 2:
            # Add a leading axis so the statistics line up with a batched
            # input through automatic broadcasting.
            mu, sigma = mu.unsqueeze(0), sigma.unsqueeze(0)
        standardized = (v - mu) / sigma
        return torch.clamp(standardized, -bound, bound)

    def denormalize(self, v):
        """
        Invert the standardization (without any clamping).

        :param v: normalized tensor.
        :return: ``mean + v * std``.
        """
        # Bring the statistics up to date before they are read.
        if not self.synchronized:
            self.synchronize()
        mu = ptu.from_numpy(self.mean)
        sigma = ptu.from_numpy(self.std)
        if v.dim() == 2:
            mu, sigma = mu.unsqueeze(0), sigma.unsqueeze(0)
        rescaled = v * sigma
        return mu + rescaled
class TorchFixedNormalizer(FixedNormalizer):
    """
    Variant of FixedNormalizer whose de/normalize methods operate on
    pytorch Tensors, using the ``mean`` and ``std`` stored on the instance.
    """

    def normalize(self, v, clip_range=None):
        """
        Standardize ``v`` with the stored mean/std and clamp the result.

        :param v: tensor to normalize.
        :param clip_range: symmetric clamp bound; when None,
            ``self.default_clip_range`` is used instead.
        :return: ``(v - mean) / std`` clamped to
            ``[-clip_range, clip_range]``.
        """
        bound = self.default_clip_range if clip_range is None else clip_range
        mu = ptu.from_numpy(self.mean)
        sigma = ptu.from_numpy(self.std)
        if v.dim() == 2:
            # Add a leading axis so the statistics line up with a batched
            # input through automatic broadcasting.
            mu, sigma = mu.unsqueeze(0), sigma.unsqueeze(0)
        standardized = (v - mu) / sigma
        return torch.clamp(standardized, -bound, bound)

    def normalize_scale(self, v):
        """
        Only normalize the scale: divide by std, do not subtract the mean.
        """
        sigma = ptu.from_numpy(self.std)
        if v.dim() == 2:
            sigma = sigma.unsqueeze(0)
        return v / sigma

    def denormalize(self, v):
        """
        Invert the standardization (without any clamping).

        :param v: normalized tensor.
        :return: ``mean + v * std``.
        """
        mu = ptu.from_numpy(self.mean)
        sigma = ptu.from_numpy(self.std)
        if v.dim() == 2:
            mu, sigma = mu.unsqueeze(0), sigma.unsqueeze(0)
        rescaled = v * sigma
        return mu + rescaled

    def denormalize_scale(self, v):
        """
        Only denormalize the scale: multiply by std, do not add the mean.
        """
        sigma = ptu.from_numpy(self.std)
        if v.dim() == 2:
            sigma = sigma.unsqueeze(0)
        return v * sigma