Laura Cabayol Garcia commited on
Commit
ecbaf8f
·
1 Parent(s): ca383de
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -31,10 +31,18 @@ def load_models(model_path: Path):
31
  nn_z = MeasureZ(num_gauss=6)
32
 
33
  nn_features.load_state_dict(
34
- torch.load(model_path / "modelF_DA.pt", map_location=torch.device("cpu"))
 
 
 
 
35
  )
36
  nn_z.load_state_dict(
37
- torch.load(model_path / "modelZ_DA.pt", map_location=torch.device("cpu"))
 
 
 
 
38
  )
39
 
40
  return nn_features, nn_z
@@ -123,12 +131,4 @@ if __name__ == "__main__":
123
  logger.error(f"Failed to load models: {e}")
124
  raise
125
 
126
- interface = gr.Interface(
127
- fn=predict,
128
- inputs=[gr.File(label="Upload CSV file", file_types=[".csv"], type="filepath")],
129
- outputs=[gr.JSON(label="Predictions")],
130
- title="Photometric Redshift Prediction",
131
- description="Upload a CSV file containing flux measurements to get redshift predictions.",
132
- )
133
-
134
- interface.launch(show_error=True, debug=True)
 
31
  nn_z = MeasureZ(num_gauss=6)
32
 
33
  nn_features.load_state_dict(
34
+ torch.load(
35
+ model_path / "modelF_DA.pt",
36
+ weights_only=True,
37
+ map_location=torch.device("cpu"),
38
+ )
39
  )
40
  nn_z.load_state_dict(
41
+ torch.load(
42
+ model_path / "modelZ_DA.pt",
43
+ weights_only=True,
44
+ map_location=torch.device("cpu"),
45
+ )
46
  )
47
 
48
  return nn_features, nn_z
 
131
  logger.error(f"Failed to load models: {e}")
132
  raise
133
 
134
+ interface.launch()