MilesCranmer commited on
Commit
5617815
1 Parent(s): b5b74c3

Make y flat if only one output feature

Browse files
Files changed (1) hide show
  1. pysr/sr.py +5 -4
pysr/sr.py CHANGED
@@ -278,12 +278,13 @@ def pysr(X=None, y=None, weights=None,
278
  if X is None:
279
  X, y = _using_test_input(X, test, y)
280
 
281
- if len(y.shape) == 2:
282
- multioutput = True
283
- nout = y.shape[1]
284
- elif len(y.shape) == 1:
285
  multioutput = False
286
  nout = 1
 
 
 
 
287
  else:
288
  raise NotImplementedError("y shape not supported!")
289
 
 
278
  if X is None:
279
  X, y = _using_test_input(X, test, y)
280
 
281
+ if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
 
 
 
282
  multioutput = False
283
  nout = 1
284
+ y = y.reshape(-1)
285
+ elif len(y.shape) == 2:
286
+ multioutput = True
287
+ nout = y.shape[1]
288
  else:
289
  raise NotImplementedError("y shape not supported!")
290