Spaces:
Sleeping
Sleeping
File size: 2,278 Bytes
eaf2e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import torch
import src.rlkit.torch.pytorch_util as ptu
import numpy as np
from src.rlkit.data_management.normalizer import Normalizer, FixedNormalizer
class TorchNormalizer(Normalizer):
    """
    Normalizer that is updated with numpy arrays (through the base class)
    but normalizes / denormalizes PyTorch tensors.

    On each call the numpy ``self.mean`` / ``self.std`` are converted with
    ``ptu.from_numpy``; if ``self.synchronized`` is false, ``synchronize()``
    is called first.
    """

    @staticmethod
    def _broadcast_stat(stat, v):
        """Convert numpy statistic ``stat`` to a tensor for use with ``v``.

        When ``v`` is 2-D a leading axis of size 1 is prepended so the
        statistic lines up with the batch axis; otherwise the converted
        tensor is returned as-is.
        """
        as_tensor = ptu.from_numpy(stat)
        return as_tensor.unsqueeze(0) if v.dim() == 2 else as_tensor

    def normalize(self, v, clip_range=None):
        """Return ``(v - mean) / std`` clipped to ``[-clip_range, clip_range]``.

        :param v: tensor to normalize.
        :param clip_range: clipping bound; ``None`` selects
            ``self.default_clip_range``.
        :return: the normalized, clipped tensor.
        """
        if not self.synchronized:
            self.synchronize()
        bound = self.default_clip_range if clip_range is None else clip_range
        mean = self._broadcast_stat(self.mean, v)
        std = self._broadcast_stat(self.std, v)
        standardized = (v - mean) / std
        return torch.clamp(standardized, -bound, bound)

    def denormalize(self, v):
        """Return ``mean + v * std`` (the inverse of ``normalize`` before clipping).

        :param v: normalized tensor.
        :return: tensor mapped back to the original scale.
        """
        if not self.synchronized:
            self.synchronize()
        mean = self._broadcast_stat(self.mean, v)
        std = self._broadcast_stat(self.std, v)
        return mean + v * std
class TorchFixedNormalizer(FixedNormalizer):
    """
    PyTorch-tensor front end for ``FixedNormalizer``.

    The numpy ``self.mean`` / ``self.std`` are converted with
    ``ptu.from_numpy`` on every call and applied to the given tensor.
    """

    @staticmethod
    def _broadcast_stat(stat, v):
        """Convert numpy statistic ``stat`` to a tensor for use with ``v``.

        When ``v`` is 2-D a leading axis of size 1 is prepended so the
        statistic lines up with the batch axis; otherwise the converted
        tensor is returned as-is.
        """
        as_tensor = ptu.from_numpy(stat)
        return as_tensor.unsqueeze(0) if v.dim() == 2 else as_tensor

    def normalize(self, v, clip_range=None):
        """Return ``(v - mean) / std`` clipped to ``[-clip_range, clip_range]``.

        :param v: tensor to normalize.
        :param clip_range: clipping bound; ``None`` selects
            ``self.default_clip_range``.
        :return: the normalized, clipped tensor.
        """
        bound = self.default_clip_range if clip_range is None else clip_range
        mean = self._broadcast_stat(self.mean, v)
        std = self._broadcast_stat(self.std, v)
        standardized = (v - mean) / std
        return torch.clamp(standardized, -bound, bound)

    def normalize_scale(self, v):
        """
        Only normalize the scale: return ``v / std`` without subtracting the mean.
        """
        std = self._broadcast_stat(self.std, v)
        return v / std

    def denormalize(self, v):
        """Return ``mean + v * std`` (the inverse of ``normalize`` before clipping).

        :param v: normalized tensor.
        :return: tensor mapped back to the original scale.
        """
        mean = self._broadcast_stat(self.mean, v)
        std = self._broadcast_stat(self.std, v)
        return mean + v * std

    def denormalize_scale(self, v):
        """
        Only denormalize the scale: return ``v * std`` without adding the mean.
        """
        std = self._broadcast_stat(self.std, v)
        return v * std
|