MilesCranmer commited on
Commit
da5e3e7
1 Parent(s): e635a4f

Fix bug with 1 feature; fixes #3

Browse files
Files changed (1) hide show
  1. pysr/sr.py +4 -1
pysr/sr.py CHANGED
@@ -151,7 +151,10 @@ const mutationWeights = [
151
  assert len(weights.shape) == 1
152
  assert X.shape[0] == weights.shape[0]
153
 
154
- X_str = str(X.tolist()).replace('],', '];').replace(',', '')
 
 
 
155
  y_str = str(y.tolist())
156
 
157
  def_datasets = """const X = convert(Array{Float32, 2}, """f"{X_str})""""
 
151
  assert len(weights.shape) == 1
152
  assert X.shape[0] == weights.shape[0]
153
 
154
+ if X.shape[1] == 1:
155
+ X_str = 'transpose([' + str(X.tolist()).replace(']', '').replace(',', '').replace('[', '') + '])'
156
+ else:
157
+ X_str = str(X.tolist()).replace('],', '];').replace(',', '')
158
  y_str = str(y.tolist())
159
 
160
  def_datasets = """const X = convert(Array{Float32, 2}, """f"{X_str})""""