MilesCranmer commited on
Commit
62d539c
1 Parent(s): db11d11

Clean up anti-patterns

Browse files
Files changed (1) hide show
  1. pysr/sr.py +6 -9
pysr/sr.py CHANGED
@@ -289,14 +289,14 @@ def pysr(
289
  variable_names = [f"x{i}" for i in range(X.shape[1])]
290
 
291
  if extra_jax_mappings is not None:
292
- for key, value in extra_jax_mappings:
293
  if not isinstance(value, str):
294
  raise NotImplementedError(
295
  "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
296
  )
297
 
298
  if extra_torch_mappings is not None:
299
- for key, value in extra_jax_mappings:
300
  if not callable(value):
301
  raise NotImplementedError(
302
  "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
@@ -797,8 +797,7 @@ def _handle_constraints(binary_operators, constraints, unary_operators, **kwargs
797
  def _create_inline_operators(binary_operators, unary_operators, **kwargs):
798
  def_hyperparams = ""
799
  for op_list in [binary_operators, unary_operators]:
800
- for i in range(len(op_list)):
801
- op = op_list[i]
802
  is_user_defined_operator = "(" in op
803
 
804
  if is_user_defined_operator:
@@ -806,8 +805,8 @@ def _create_inline_operators(binary_operators, unary_operators, **kwargs):
806
  # Cut off from the first non-alphanumeric char:
807
  first_non_char = [
808
  j
809
- for j in range(len(op))
810
- if not (op[j].isalpha() or op[j].isdigit())
811
  ][0]
812
  function_name = op[:first_non_char]
813
  op_list[i] = function_name
@@ -823,9 +822,7 @@ def _handle_feature_selection(
823
  X = X[:, selection]
824
 
825
  if use_custom_variable_names:
826
- variable_names = [
827
- variable_names[selection[i]] for i in range(len(selection))
828
- ]
829
  else:
830
  selection = None
831
  return X, variable_names, selection
 
289
  variable_names = [f"x{i}" for i in range(X.shape[1])]
290
 
291
  if extra_jax_mappings is not None:
292
+ for value in extra_jax_mappings.values():
293
  if not isinstance(value, str):
294
  raise NotImplementedError(
295
  "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
296
  )
297
 
298
  if extra_torch_mappings is not None:
299
+ for value in extra_jax_mappings.values():
300
  if not callable(value):
301
  raise NotImplementedError(
302
  "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
 
797
  def _create_inline_operators(binary_operators, unary_operators, **kwargs):
798
  def_hyperparams = ""
799
  for op_list in [binary_operators, unary_operators]:
800
+ for i, op in enumerate(op_list):
 
801
  is_user_defined_operator = "(" in op
802
 
803
  if is_user_defined_operator:
 
805
  # Cut off from the first non-alphanumeric char:
806
  first_non_char = [
807
  j
808
+ for j, char in enumerate(op)
809
+ if not (char.isalpha() or char.isdigit())
810
  ][0]
811
  function_name = op[:first_non_char]
812
  op_list[i] = function_name
 
822
  X = X[:, selection]
823
 
824
  if use_custom_variable_names:
825
+ variable_names = [variable_names[i] for i in selection]
 
 
826
  else:
827
  selection = None
828
  return X, variable_names, selection