Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
aadb328
1
Parent(s):
18afca5
Add ability to pass strings defining operators
Browse files- pysr/sr.py +29 -8
pysr/sr.py
CHANGED
@@ -92,6 +92,17 @@ def pysr(X=None, y=None, weights=None, threads=4,
|
|
92 |
|
93 |
"""
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
rand_string = f'{"".join([str(np.random.rand())[2] for i in range(20)])}'
|
96 |
|
97 |
if isinstance(binary_operators, str): binary_operators = [binary_operators]
|
@@ -115,7 +126,24 @@ def pysr(X=None, y=None, weights=None, threads=4,
|
|
115 |
|
116 |
pkg_directory = '/'.join(__file__.split('/')[:-2] + ['julia'])
|
117 |
|
118 |
-
def_hyperparams =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
const binops = {'[' + ', '.join(binary_operators) + ']'}
|
120 |
const unaops = {'[' + ', '.join(unary_operators) + ']'}
|
121 |
const ns=10;
|
@@ -144,13 +172,6 @@ const mutationWeights = [
|
|
144 |
]
|
145 |
"""
|
146 |
|
147 |
-
assert len(X.shape) == 2
|
148 |
-
assert len(y.shape) == 1
|
149 |
-
assert X.shape[0] == y.shape[0]
|
150 |
-
if weights is not None:
|
151 |
-
assert len(weights.shape) == 1
|
152 |
-
assert X.shape[0] == weights.shape[0]
|
153 |
-
|
154 |
if X.shape[1] == 1:
|
155 |
X_str = 'transpose([' + str(X.tolist()).replace(']', '').replace(',', '').replace('[', '') + '])'
|
156 |
else:
|
|
|
92 |
|
93 |
"""
|
94 |
|
95 |
+
# Check for potential errors before they happen
|
96 |
+
assert len(binary_operators) > 0
|
97 |
+
assert len(unary_operators) > 0
|
98 |
+
assert len(X.shape) == 2
|
99 |
+
assert len(y.shape) == 1
|
100 |
+
assert X.shape[0] == y.shape[0]
|
101 |
+
if weights is not None:
|
102 |
+
assert len(weights.shape) == 1
|
103 |
+
assert X.shape[0] == weights.shape[0]
|
104 |
+
|
105 |
+
|
106 |
rand_string = f'{"".join([str(np.random.rand())[2] for i in range(20)])}'
|
107 |
|
108 |
if isinstance(binary_operators, str): binary_operators = [binary_operators]
|
|
|
126 |
|
127 |
pkg_directory = '/'.join(__file__.split('/')[:-2] + ['julia'])
|
128 |
|
129 |
+
def_hyperparams = ""
|
130 |
+
|
131 |
+
# Add pre-defined functions to Julia
|
132 |
+
for op_list in [binary_operators, unary_operators]:
|
133 |
+
for i in range(len(op_list)):
|
134 |
+
op = op_list[i]
|
135 |
+
if '(' not in op:
|
136 |
+
continue
|
137 |
+
|
138 |
+
def_hyperparams += op + "\n"
|
139 |
+
first_non_char = [
|
140 |
+
j for j in range(len(op))
|
141 |
+
if not (op[j].isalpha() or op[j].isdigit())][0]
|
142 |
+
function_name = op[:first_non_char]
|
143 |
+
op_list[i] = function_name
|
144 |
+
print(op_list)
|
145 |
+
|
146 |
+
def_hyperparams += f"""include("{pkg_directory}/operators.jl")
|
147 |
const binops = {'[' + ', '.join(binary_operators) + ']'}
|
148 |
const unaops = {'[' + ', '.join(unary_operators) + ']'}
|
149 |
const ns=10;
|
|
|
172 |
]
|
173 |
"""
|
174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
if X.shape[1] == 1:
|
176 |
X_str = 'transpose([' + str(X.tolist()).replace(']', '').replace(',', '').replace('[', '') + '])'
|
177 |
else:
|