lauracabayol commited on
Commit
ca6b4b2
·
1 Parent(s): 4f50b0c
Files changed (1) hide show
  1. temps/temps.py +2 -2
temps/temps.py CHANGED
@@ -190,8 +190,8 @@ class TempsModule:
190
  def get_pz(self, input_data, return_pz=True, return_flag=True, return_odds=False):
191
  """Get the predicted z values and their uncertainties."""
192
  logger.info("Predicting photo-z for the input galaxies...")
193
- self.model_z.eval()
194
- self.model_f.eval()
195
 
196
  input_data = input_data.to(self.device)
197
  features = self.model_f(input_data)
 
190
  def get_pz(self, input_data, return_pz=True, return_flag=True, return_odds=False):
191
  """Get the predicted z values and their uncertainties."""
192
  logger.info("Predicting photo-z for the input galaxies...")
193
+ self.model_z.eval().to(self.device)
194
+ self.model_f.eval().to(self.device)
195
 
196
  input_data = input_data.to(self.device)
197
  features = self.model_f(input_data)