Spaces:
Sleeping
Sleeping
class Transformation:
    """Base class for constraint-driven data transformations.

    The base implementation acts as the identity: ``transform`` hands back
    its input untouched. Subclasses are expected to override ``transform``
    and ``get_params``.
    """

    def __init__(self, index, name="Identity Transformation"):
        """Record the transformation's identifier and display name.

        Args:
            index: Value stored as ``self.index`` (the subclasses visible in
                this file each pass a distinct integer).
            name: Human-readable description; returned by ``str()`` and
                ``repr()``.
        """
        self.name = name
        self.index = index

    def transform(self, X):
        """Return an augmented data point based on the constraint.

        Args:
            X: Data of shape (n, d).

        Returns:
            ``X`` itself for the base class (no copy is made).
        """
        return X

    def __str__(self):
        """Return the display name as a string."""
        return str(self.name)

    def __repr__(self):
        """Mirror ``__str__`` so containers of transformations print readably."""
        return str(self)

    def get_params(self):
        """Return the parameters describing this transformation.

        Raises:
            NotImplementedError: Always, in the base class; subclasses
                supply the actual parameters.
        """
        raise NotImplementedError
class SymTransformation(Transformation):
    """Transformation that exchanges two columns of the data.

    Expresses a symmetry between two variables: swapping their columns
    yields an equally valid data point.
    """

    def __init__(self, x1=0, x2=1):
        """Set up the symmetry between two variables.

        Args:
            x1: Column index of the first symmetric variable.
            x2: Column index of the second symmetric variable.
        """
        super().__init__(1, name=f"Symmetry Between Variable {x1} and {x2}")
        self.x1 = x1
        self.x2 = x2

    def transform(self, X):
        """Return a copy of ``X`` with columns ``x1`` and ``x2`` swapped.

        Args:
            X: 2-D array-like supporting ``.copy()`` and ``[:, i]`` indexing
                (e.g. a NumPy array). It is not modified.

        Returns:
            A new array with the two columns exchanged.
        """
        swapped = X.copy()
        # Both reads come from the untouched input, so the second assignment
        # never observes the first write; assignment copies the values, so no
        # extra per-column copy is needed.
        swapped[:, self.x2] = X[:, self.x1]
        swapped[:, self.x1] = X[:, self.x2]
        return swapped

    def get_params(self):
        """Return the two column indices as ``[x1, x2]``."""
        return [self.x1, self.x2]
class ZeroTransformation(Transformation):
    """Transformation that sets selected columns of the data to zero."""

    def __init__(self, inds=None):
        """Set up the zero constraint.

        Args:
            inds: List of column indices to set to 0. Defaults to ``[0]``.
        """
        # A fresh list per instance: a literal ``[0]`` default would be one
        # shared object, so mutating one instance's ``inds`` would silently
        # change the default of every later instance.
        if inds is None:
            inds = [0]
        super().__init__(2, name=f"Zero Constraint for Variables {inds}")
        self.inds = inds

    def transform(self, X):
        """Return a copy of ``X`` with every column in ``inds`` set to 0.

        Args:
            X: 2-D array-like supporting ``.copy()`` and ``[:, i]`` indexing
                (e.g. a NumPy array). It is not modified.

        Returns:
            A new array with the selected columns zeroed.
        """
        zeroed = X.copy()
        for col in self.inds:
            zeroed[:, col] = 0
        return zeroed

    def get_params(self):
        """Return the constrained column indices as a new list."""
        return list(self.inds)
class ValueTransformation(Transformation):
    """Transformation that forces several columns to share one column's values."""

    def __init__(self, inds=None):
        """Set up the value constraint.

        Args:
            inds: List of column indices that are set to the same value as
                the column named by the first element of the list.
                Defaults to ``[0]``.
        """
        # A fresh list per instance: a literal ``[0]`` default would be one
        # shared object reused (and mutable) across all instances.
        if inds is None:
            inds = [0]
        super().__init__(3, name=f"Value Constraint for Variables {inds}")
        self.inds = inds

    def transform(self, X):
        """Return a copy of ``X`` where the columns in ``inds[1:]`` equal column ``inds[0]``.

        Args:
            X: 2-D array-like supporting ``.copy()`` and ``[:, i]`` indexing
                (e.g. a NumPy array). It is not modified.

        Returns:
            A new array with the constrained columns overwritten. If
            ``inds`` is empty there is nothing to constrain and an
            unchanged copy is returned.
        """
        result = X.copy()
        # An empty constraint is vacuous; without this guard ``inds[0]``
        # would raise IndexError.
        if len(self.inds) == 0:
            return result
        # Read the reference column from the input rather than from the
        # array being written to.
        reference = X[:, self.inds[0]]
        for col in self.inds[1:]:
            result[:, col] = reference
        return result

    def get_params(self):
        """Return the constrained column indices as a new list."""
        return list(self.inds)