Laura Cabayol Garcia commited on
Commit
578b609
·
1 Parent(s): d6eb94f

debugg app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -48
app.py CHANGED
@@ -1,77 +1,73 @@
1
- # Add this at the very top of the file, before other imports
2
- print("Script starting...")
3
- import sys
4
-
5
- print(f"Python version: {sys.version}")
6
-
7
  from __future__ import annotations # This should actually be the first import
8
  import argparse
9
  import logging
10
  from pathlib import Path
11
 
12
- print("Starting to import libraries...")
13
  import gradio as gr
14
  import pandas as pd
15
  import torch
16
-
17
- print("Libraries imported successfully")
18
 
19
  from temps.temps_arch import EncoderPhotometry, MeasureZ
20
  from temps.temps import TempsModule
 
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
- # Define the prediction function that will be called by Gradio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def predict(input_file_path: Path):
26
- model_path = Path("models/")
 
 
 
27
 
28
- logger.info("Loading data and converting fluxes to colors")
29
 
30
- # Load the input data file (CSV)
31
  try:
32
  fluxes = pd.read_csv(input_file_path, sep=",", header=0)
33
  except Exception as e:
34
  logger.error(f"Error loading input file: {e}")
35
  return f"Error loading file: {e}"
36
 
37
- # Assuming that the model expects "colors" as input
38
- colors = fluxes.iloc[:, :-1] / fluxes.iloc[:, 1:]
39
-
40
- logger.info("Loading model...")
41
-
42
- # Load the neural network models from the given model path
43
- nn_features = EncoderPhotometry()
44
- nn_z = MeasureZ(num_gauss=6)
45
-
46
- try:
47
- nn_features.load_state_dict(
48
- torch.load(model_path / "modelF.pt", map_location=torch.device("cpu"))
49
- )
50
- nn_z.load_state_dict(
51
- torch.load(model_path / "modelZ.pt", map_location=torch.device("cpu"))
52
- )
53
- except Exception as e:
54
- logger.error(f"Error loading model: {e}")
55
- return f"Error loading model: {e}"
56
 
57
  temps_module = TempsModule(nn_features, nn_z)
58
 
59
- # Run predictions
60
  try:
61
  z, pz, odds = temps_module.get_pz(
62
- input_data=torch.Tensor(colors.values), return_pz=True, return_flag=True
63
  )
64
  except Exception as e:
65
  logger.error(f"Error during prediction: {e}")
66
  return f"Error during prediction: {e}"
67
 
68
- # Return the predictions as a dictionary
69
- result = {
70
- "redshift (z)": z.tolist(),
71
- "posterior (pz)": pz.tolist(),
72
- "odds": odds.tolist(),
73
- }
74
- return result
75
 
76
 
77
  def get_args() -> argparse.Namespace:
@@ -115,11 +111,24 @@ interface = gr.Interface(
115
  if __name__ == "__main__":
116
  args = get_args()
117
  logging.basicConfig(level=args.log_level)
118
- logger.info(f"Starting server on {args.server_address}:{args.port}")
119
- interface.launch(
120
- server_name=args.server_address,
121
- server_port=args.port,
122
- share=True,
123
- debug=True,
124
- show_error=True,
 
 
 
 
 
 
 
 
 
 
 
125
  )
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations # This should actually be the first import
2
  import argparse
3
  import logging
4
  from pathlib import Path
5
 
 
6
  import gradio as gr
7
  import pandas as pd
8
  import torch
9
+ import os
 
10
 
11
  from temps.temps_arch import EncoderPhotometry, MeasureZ
12
  from temps.temps import TempsModule
13
+ from temps.constants import PROJ_ROOT
14
 
15
  logger = logging.getLogger(__name__)
16
 
17
+
18
+ def get_model_path() -> Path:
19
+ """Get the appropriate model path for both local and Docker/HF environments"""
20
+ if os.environ.get("SPACE_ID"):
21
+ # HuggingFace Spaces - models will be in the root directory
22
+ logger.info("Running on HuggingFace Spaces")
23
+ return Path("data/models") # Absolute path to models in HF Spaces
24
+ else:
25
+ return PROJ_ROOT / "data/models/"
26
+
27
+
28
+ def load_models(model_path: Path):
29
+ logger.info(f"Loading models from {model_path}")
30
+ nn_features = EncoderPhotometry()
31
+ nn_z = MeasureZ(num_gauss=6)
32
+
33
+ nn_features.load_state_dict(
34
+ torch.load(model_path / "modelF_DA.pt", map_location=torch.device("cpu"))
35
+ )
36
+ nn_z.load_state_dict(
37
+ torch.load(model_path / "modelZ_DA.pt", map_location=torch.device("cpu"))
38
+ )
39
+
40
+ return nn_features, nn_z
41
+
42
+
43
  def predict(input_file_path: Path):
44
+ global LOADED_MODELS
45
+ if LOADED_MODELS is None:
46
+ logger.error("Models not loaded!")
47
+ return "Error: Models not initialized"
48
 
49
+ nn_features, nn_z = LOADED_MODELS
50
 
51
+ # Rest of your predict function, but use the pre-loaded models
52
  try:
53
  fluxes = pd.read_csv(input_file_path, sep=",", header=0)
54
  except Exception as e:
55
  logger.error(f"Error loading input file: {e}")
56
  return f"Error loading file: {e}"
57
 
58
+ colors = fluxes.values[:, :-1] / fluxes.values[:, 1:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  temps_module = TempsModule(nn_features, nn_z)
61
 
 
62
  try:
63
  z, pz, odds = temps_module.get_pz(
64
+ input_data=torch.Tensor(colors), return_pz=True, return_flag=True
65
  )
66
  except Exception as e:
67
  logger.error(f"Error during prediction: {e}")
68
  return f"Error during prediction: {e}"
69
 
70
+ return (z.tolist(),)
 
 
 
 
 
 
71
 
72
 
73
  def get_args() -> argparse.Namespace:
 
111
  if __name__ == "__main__":
112
  args = get_args()
113
  logging.basicConfig(level=args.log_level)
114
+
115
+ # Load models before creating the interface
116
+ try:
117
+ # model_path = PROJ_ROOT / "data/models/"
118
+ model_path = get_model_path()
119
+ logger.info("Loading models...")
120
+ LOADED_MODELS = load_models(model_path)
121
+ logger.info("Models loaded successfully")
122
+ except Exception as e:
123
+ logger.error(f"Failed to load models: {e}")
124
+ raise
125
+
126
+ interface = gr.Interface(
127
+ fn=predict,
128
+ inputs=[gr.File(label="Upload CSV file", file_types=[".csv"], type="filepath")],
129
+ outputs=[gr.JSON(label="Predictions")],
130
+ title="Photometric Redshift Prediction",
131
+ description="Upload a CSV file containing flux measurements to get redshift predictions.",
132
  )
133
+
134
+ interface.launch(show_error=True, debug=True)