PySR / lgga /Transformation.py
AutonLabTruth's picture
Added Truth Python Files
b865169
raw
history blame
1.98 kB
class Transformation:
    """
    Base class for data transformations; on its own it acts as the identity.

    Subclasses override `transform` to produce an altered copy of the data
    and `get_params` to report the parameters that define them.
    """

    def __init__(self, index, name="Identity Transformation"):
        """
        index = tag stored on the instance (its meaning is decided by callers)
        name = label used by str() and repr()
        """
        self.index, self.name = index, name

    def transform(self, X):
        """
        Takes in a data point of shape (n, d) and returns an augmented data
        point based on the constraint. The base class hands back X itself,
        unchanged and uncopied.
        """
        return X

    def __str__(self):
        # str() is applied so that a non-string name still yields a string.
        label = self.name
        return str(label)

    def __repr__(self):
        # Delegates to __str__ so both representations always agree.
        return str(self)

    def get_params(self):
        """
        Parameters describing the transformation; the base class has none
        and always raises NotImplementedError.
        """
        raise NotImplementedError
class SymTransformation(Transformation):
    """
    Symmetry constraint: produces data in which two variables (columns)
    have traded places.
    """

    def __init__(self, x1=0, x2=1):
        """
        x1, x2 = indices of the variables which are symmetric
        """
        super().__init__(1, name=f"Symmetry Between Variable {x1} and {x2}")
        self.x1, self.x2 = x1, x2

    def transform(self, X):
        """
        Return a copy of X whose columns x1 and x2 are exchanged.

        X must support .copy() and [:, i] indexing (presumably a NumPy
        array of shape (n, d) -- confirm against callers). X itself is
        left untouched.
        """
        first, second = self.x1, self.x2
        swapped = X.copy()
        # Both right-hand sides read from the unmodified input, so the two
        # assignments cannot interfere with each other.
        swapped[:, second] = X[:, first]
        swapped[:, first] = X[:, second]
        return swapped

    def get_params(self):
        """
        Return the two symmetric variable indices as a new list [x1, x2].
        """
        params = [self.x1, self.x2]
        return params
class ZeroTransformation(Transformation):
    """
    Zero constraint: produces data in which the selected variables
    (columns) are set to 0.
    """

    def __init__(self, inds=None):
        """
        inds is a list of indices to set to 0 (defaults to [0])
        """
        # A None sentinel replaces the former mutable default `[0]`, which
        # was shared by every instance created without arguments.
        if inds is None:
            inds = [0]
        # The name is formatted from the caller's object so the text is the
        # same as before (e.g. a tuple still prints as a tuple).
        super().__init__(2, name=f"Zero Constraint for Variables {inds}")
        # Private copy: later changes to the caller's list (or to another
        # instance's inds) cannot silently alter this transformation.
        self.inds = list(inds)

    def transform(self, X):
        """
        Return a copy of X with every column listed in inds set to 0.

        X must support .copy() and [:, i] item assignment (presumably a
        NumPy array of shape (n, d) -- confirm against callers). X itself
        is not modified; an empty inds yields an unchanged copy.
        """
        temp = X.copy()
        for ind in self.inds:
            temp[:, ind] = 0
        return temp

    def get_params(self):
        """
        Return the zeroed indices as a new list.
        """
        return list(self.inds)
class ValueTransformation(Transformation):
    """
    Value constraint: produces data in which several variables (columns)
    all carry the value of the first listed variable.
    """

    def __init__(self, inds=None):
        """
        inds is list of indices to set to the same value as the first element
        in that list (defaults to [0])
        """
        # A None sentinel replaces the former mutable default `[0]`, which
        # was shared by every instance created without arguments.
        if inds is None:
            inds = [0]
        # The name is formatted from the caller's object so the text is the
        # same as before (e.g. a tuple still prints as a tuple).
        super().__init__(3, name=f"Value Constraint for Variables {inds}")
        # Private copy: later changes to the caller's list (or to another
        # instance's inds) cannot silently alter this transformation.
        self.inds = list(inds)

    def transform(self, X):
        """
        Return a copy of X in which every column in inds[1:] is overwritten
        with the contents of column inds[0].

        X must support .copy() and [:, i] indexing/assignment (presumably a
        NumPy array of shape (n, d) -- confirm against callers). X itself
        is not modified. With zero or one index there is nothing to
        equalise and an unchanged copy is returned.
        """
        temp = X.copy()
        # Guard the empty case: indexing inds[0] would otherwise raise
        # IndexError even though the constraint is simply vacuous.
        if not self.inds:
            return temp
        val = temp[:, self.inds[0]]
        for ind in self.inds[1:]:
            temp[:, ind] = val
        return temp

    def get_params(self):
        """
        Return the constrained indices as a new list.
        """
        return list(self.inds)