Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
62d539c
1
Parent(s):
db11d11
Clean up anti-patterns
Browse files- pysr/sr.py +6 -9
pysr/sr.py
CHANGED
@@ -289,14 +289,14 @@ def pysr(
|
|
289 |
variable_names = [f"x{i}" for i in range(X.shape[1])]
|
290 |
|
291 |
if extra_jax_mappings is not None:
|
292 |
-
for
|
293 |
if not isinstance(value, str):
|
294 |
raise NotImplementedError(
|
295 |
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
296 |
)
|
297 |
|
298 |
if extra_torch_mappings is not None:
|
299 |
-
for
|
300 |
if not callable(value):
|
301 |
raise NotImplementedError(
|
302 |
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
@@ -797,8 +797,7 @@ def _handle_constraints(binary_operators, constraints, unary_operators, **kwargs
|
|
797 |
def _create_inline_operators(binary_operators, unary_operators, **kwargs):
|
798 |
def_hyperparams = ""
|
799 |
for op_list in [binary_operators, unary_operators]:
|
800 |
-
for i in
|
801 |
-
op = op_list[i]
|
802 |
is_user_defined_operator = "(" in op
|
803 |
|
804 |
if is_user_defined_operator:
|
@@ -806,8 +805,8 @@ def _create_inline_operators(binary_operators, unary_operators, **kwargs):
|
|
806 |
# Cut off from the first non-alphanumeric char:
|
807 |
first_non_char = [
|
808 |
j
|
809 |
-
for j in
|
810 |
-
if not (
|
811 |
][0]
|
812 |
function_name = op[:first_non_char]
|
813 |
op_list[i] = function_name
|
@@ -823,9 +822,7 @@ def _handle_feature_selection(
|
|
823 |
X = X[:, selection]
|
824 |
|
825 |
if use_custom_variable_names:
|
826 |
-
variable_names = [
|
827 |
-
variable_names[selection[i]] for i in range(len(selection))
|
828 |
-
]
|
829 |
else:
|
830 |
selection = None
|
831 |
return X, variable_names, selection
|
|
|
289 |
variable_names = [f"x{i}" for i in range(X.shape[1])]
|
290 |
|
291 |
if extra_jax_mappings is not None:
|
292 |
+
for value in extra_jax_mappings.values():
|
293 |
if not isinstance(value, str):
|
294 |
raise NotImplementedError(
|
295 |
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
296 |
)
|
297 |
|
298 |
if extra_torch_mappings is not None:
|
299 |
+
for value in extra_jax_mappings.values():
|
300 |
if not callable(value):
|
301 |
raise NotImplementedError(
|
302 |
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
|
|
797 |
def _create_inline_operators(binary_operators, unary_operators, **kwargs):
|
798 |
def_hyperparams = ""
|
799 |
for op_list in [binary_operators, unary_operators]:
|
800 |
+
for i, op in enumerate(op_list):
|
|
|
801 |
is_user_defined_operator = "(" in op
|
802 |
|
803 |
if is_user_defined_operator:
|
|
|
805 |
# Cut off from the first non-alphanumeric char:
|
806 |
first_non_char = [
|
807 |
j
|
808 |
+
for j, char in enumerate(op)
|
809 |
+
if not (char.isalpha() or char.isdigit())
|
810 |
][0]
|
811 |
function_name = op[:first_non_char]
|
812 |
op_list[i] = function_name
|
|
|
822 |
X = X[:, selection]
|
823 |
|
824 |
if use_custom_variable_names:
|
825 |
+
variable_names = [variable_names[i] for i in selection]
|
|
|
|
|
826 |
else:
|
827 |
selection = None
|
828 |
return X, variable_names, selection
|