Spaces:
Running
Running
File size: 1,978 Bytes
b865169 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
class Transformation:
    """
    Base class for data transformations. On its own it behaves as the
    identity: transform() hands the input back unchanged.
    """

    def __init__(self, index, name="Identity Transformation"):
        """
        index = identifier stored on the instance (its meaning is not defined
                in this class)
        name = human-readable label, used by str() and repr()
        """
        self.index = index
        self.name = name

    def transform(self, X):
        """
        Takes in a data point of shape (n, d) and returns an augmented data
        point based on the constraint. The base implementation returns the
        very same object it was given, without copying.
        """
        return X

    def __str__(self):
        """The transformation's name, converted to a string."""
        return f"{self.name!s}"

    def __repr__(self):
        """Same text as str(), so instances print readably inside containers."""
        return self.__str__()

    def get_params(self):
        """
        Parameters describing the transformation. Not implemented in the base
        class; always raises NotImplementedError here.
        """
        raise NotImplementedError
class SymTransformation(Transformation):
    """Constraint stating that two variables (columns) are interchangeable."""

    def __init__(self, x1=0, x2=1):
        """
        x1, x2 = indices of the variables which are symmetric
        """
        label = f"Symmetry Between Variable {x1} and {x2}"
        super().__init__(1, name=label)
        self.x1, self.x2 = x1, x2

    def transform(self, X):
        """
        Returns a copy of X in which columns x1 and x2 have traded places.
        Only the copy is written to.
        """
        swapped = X.copy()
        # Both columns are read from the input rather than from the copy being
        # modified, so the first write cannot corrupt the source of the second.
        left_col, right_col = X[:, self.x1], X[:, self.x2]
        swapped[:, self.x2] = left_col
        swapped[:, self.x1] = right_col
        return swapped

    def get_params(self):
        """Returns the two symmetric indices as [x1, x2]."""
        return [self.x1, self.x2]
class ZeroTransformation(Transformation):
    """Constraint that forces the selected variables (columns) to zero."""

    def __init__(self, inds=None):
        """
        inds is a list of indices to set to 0. When omitted, each instance
        gets its own fresh [0].
        """
        # A literal `[0]` default is created once at definition time and would
        # be shared by every instance built without arguments; mutating one
        # instance's `inds` would then silently change all the others.
        if inds is None:
            inds = [0]
        super().__init__(2, name=f"Zero Constraint for Variables {inds}")
        self.inds = inds

    def transform(self, X):
        """
        Returns a copy of X with every column listed in inds set to 0.
        An empty inds yields an unmodified copy.
        """
        temp = X.copy()
        for ind in self.inds:
            temp[:, ind] = 0
        return temp

    def get_params(self):
        """Returns the zeroed indices as a new list."""
        return list(self.inds)
class ValueTransformation(Transformation):
    """Constraint that forces the selected variables (columns) to share a value."""

    def __init__(self, inds=None):
        """
        inds is list of indices to set to the same value as the first element
        in that list. When omitted, each instance gets its own fresh [0].
        """
        # A literal `[0]` default is created once at definition time and would
        # be shared by every instance built without arguments; mutating one
        # instance's `inds` would then silently change all the others.
        if inds is None:
            inds = [0]
        super().__init__(3, name=f"Value Constraint for Variables {inds}")
        self.inds = inds

    def transform(self, X):
        """
        Returns a copy of X in which every column listed in inds holds the
        values of column inds[0]. An empty inds yields an unmodified copy.
        """
        temp = X.copy()
        # With no indices there is no reference column, so nothing is
        # constrained. len() is used because truth-testing an array of
        # indices would be ambiguous.
        if len(self.inds) == 0:
            return temp
        val = temp[:, self.inds[0]]
        for ind in self.inds[1:]:
            temp[:, ind] = val
        return temp

    def get_params(self):
        """Returns the constrained indices as a new list."""
        return list(self.inds)