MilesCranmer commited on
Commit
aadb328
1 Parent(s): 18afca5

Add ability to pass strings defining operators

Browse files
Files changed (1) hide show
  1. pysr/sr.py +29 -8
pysr/sr.py CHANGED
@@ -92,6 +92,17 @@ def pysr(X=None, y=None, weights=None, threads=4,
92
 
93
  """
94
 
 
 
 
 
 
 
 
 
 
 
 
95
  rand_string = f'{"".join([str(np.random.rand())[2] for i in range(20)])}'
96
 
97
  if isinstance(binary_operators, str): binary_operators = [binary_operators]
@@ -115,7 +126,24 @@ def pysr(X=None, y=None, weights=None, threads=4,
115
 
116
  pkg_directory = '/'.join(__file__.split('/')[:-2] + ['julia'])
117
 
118
- def_hyperparams = f"""include("{pkg_directory}/operators.jl")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  const binops = {'[' + ', '.join(binary_operators) + ']'}
120
  const unaops = {'[' + ', '.join(unary_operators) + ']'}
121
  const ns=10;
@@ -144,13 +172,6 @@ const mutationWeights = [
144
  ]
145
  """
146
 
147
- assert len(X.shape) == 2
148
- assert len(y.shape) == 1
149
- assert X.shape[0] == y.shape[0]
150
- if weights is not None:
151
- assert len(weights.shape) == 1
152
- assert X.shape[0] == weights.shape[0]
153
-
154
  if X.shape[1] == 1:
155
  X_str = 'transpose([' + str(X.tolist()).replace(']', '').replace(',', '').replace('[', '') + '])'
156
  else:
 
92
 
93
  """
94
 
95
+ # Check for potential errors before they happen
96
+ assert len(binary_operators) > 0
97
+ assert len(unary_operators) > 0
98
+ assert len(X.shape) == 2
99
+ assert len(y.shape) == 1
100
+ assert X.shape[0] == y.shape[0]
101
+ if weights is not None:
102
+ assert len(weights.shape) == 1
103
+ assert X.shape[0] == weights.shape[0]
104
+
105
+
106
  rand_string = f'{"".join([str(np.random.rand())[2] for i in range(20)])}'
107
 
108
  if isinstance(binary_operators, str): binary_operators = [binary_operators]
 
126
 
127
  pkg_directory = '/'.join(__file__.split('/')[:-2] + ['julia'])
128
 
129
+ def_hyperparams = ""
130
+
131
+ # Add pre-defined functions to Julia
132
+ for op_list in [binary_operators, unary_operators]:
133
+ for i in range(len(op_list)):
134
+ op = op_list[i]
135
+ if '(' not in op:
136
+ continue
137
+
138
+ def_hyperparams += op + "\n"
139
+ first_non_char = [
140
+ j for j in range(len(op))
141
+ if not (op[j].isalpha() or op[j].isdigit())][0]
142
+ function_name = op[:first_non_char]
143
+ op_list[i] = function_name
144
+ print(op_list)
145
+
146
+ def_hyperparams += f"""include("{pkg_directory}/operators.jl")
147
  const binops = {'[' + ', '.join(binary_operators) + ']'}
148
  const unaops = {'[' + ', '.join(unary_operators) + ']'}
149
  const ns=10;
 
172
  ]
173
  """
174
 
 
 
 
 
 
 
 
175
  if X.shape[1] == 1:
176
  X_str = 'transpose([' + str(X.tolist()).replace(']', '').replace(',', '').replace('[', '') + '])'
177
  else: