Spaces: Runtime error

Laura Cabayol Garcia committed
Commit 578b609 · Parent(s): d6eb94f

debugg app.py

app.py CHANGED
@@ -1,77 +1,73 @@
-# Add this at the very top of the file, before other imports
-print("Script starting...")
-import sys
-
-print(f"Python version: {sys.version}")
-
 from __future__ import annotations  # This should actually be the first import
 import argparse
 import logging
 from pathlib import Path
 
-print("Starting to import libraries...")
 import gradio as gr
 import pandas as pd
 import torch
-
-print("Libraries imported successfully")
+import os
 
 from temps.temps_arch import EncoderPhotometry, MeasureZ
 from temps.temps import TempsModule
+from temps.constants import PROJ_ROOT
 
 logger = logging.getLogger(__name__)
 
-
+
+def get_model_path() -> Path:
+    """Get the appropriate model path for both local and Docker/HF environments"""
+    if os.environ.get("SPACE_ID"):
+        # HuggingFace Spaces - models will be in the root directory
+        logger.info("Running on HuggingFace Spaces")
+        return Path("data/models")  # Absolute path to models in HF Spaces
+    else:
+        return PROJ_ROOT / "data/models/"
+
+
+def load_models(model_path: Path):
+    logger.info(f"Loading models from {model_path}")
+    nn_features = EncoderPhotometry()
+    nn_z = MeasureZ(num_gauss=6)
+
+    nn_features.load_state_dict(
+        torch.load(model_path / "modelF_DA.pt", map_location=torch.device("cpu"))
+    )
+    nn_z.load_state_dict(
+        torch.load(model_path / "modelZ_DA.pt", map_location=torch.device("cpu"))
+    )
+
+    return nn_features, nn_z
+
+
 def predict(input_file_path: Path):
-
+    global LOADED_MODELS
+    if LOADED_MODELS is None:
+        logger.error("Models not loaded!")
+        return "Error: Models not initialized"
 
-
+    nn_features, nn_z = LOADED_MODELS
 
-    #
+    # Rest of your predict function, but use the pre-loaded models
     try:
         fluxes = pd.read_csv(input_file_path, sep=",", header=0)
     except Exception as e:
         logger.error(f"Error loading input file: {e}")
         return f"Error loading file: {e}"
 
-
-    colors = fluxes.iloc[:, :-1] / fluxes.iloc[:, 1:]
-
-    logger.info("Loading model...")
-
-    # Load the neural network models from the given model path
-    nn_features = EncoderPhotometry()
-    nn_z = MeasureZ(num_gauss=6)
-
-    try:
-        nn_features.load_state_dict(
-            torch.load(model_path / "modelF.pt", map_location=torch.device("cpu"))
-        )
-        nn_z.load_state_dict(
-            torch.load(model_path / "modelZ.pt", map_location=torch.device("cpu"))
-        )
-    except Exception as e:
-        logger.error(f"Error loading model: {e}")
-        return f"Error loading model: {e}"
+    colors = fluxes.values[:, :-1] / fluxes.values[:, 1:]
 
     temps_module = TempsModule(nn_features, nn_z)
 
-    # Run predictions
     try:
         z, pz, odds = temps_module.get_pz(
-            input_data=torch.Tensor(colors
+            input_data=torch.Tensor(colors), return_pz=True, return_flag=True
         )
     except Exception as e:
         logger.error(f"Error during prediction: {e}")
         return f"Error during prediction: {e}"
 
-
-    result = {
-        "redshift (z)": z.tolist(),
-        "posterior (pz)": pz.tolist(),
-        "odds": odds.tolist(),
-    }
-    return result
+    return (z.tolist(),)
 
 
 def get_args() -> argparse.Namespace:
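For context on the feature step above: predict() now turns the uploaded flux table into adjacent-band ratios with fluxes.values[:, :-1] / fluxes.values[:, 1:], so N flux columns yield N-1 color-like features. Below is a minimal sketch of that shape logic; the band count and column names are illustrative assumptions, not taken from the app.

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# 5 objects x 6 flux bands; the band names are hypothetical placeholders
fluxes = pd.DataFrame(
    rng.uniform(1.0, 10.0, size=(5, 6)),
    columns=["g", "r", "i", "z", "H", "J"],
)

# Same expression as in predict(): ratio of each band to the next band
colors = fluxes.values[:, :-1] / fluxes.values[:, 1:]
print(colors.shape)  # (5, 5)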
@@ -115,11 +111,24 @@ interface = gr.Interface(
 if __name__ == "__main__":
     args = get_args()
     logging.basicConfig(level=args.log_level)
-
-    interface
-
-
-
-
-
+
+    # Load models before creating the interface
+    try:
+        # model_path = PROJ_ROOT / "data/models/"
+        model_path = get_model_path()
+        logger.info("Loading models...")
+        LOADED_MODELS = load_models(model_path)
+        logger.info("Models loaded successfully")
+    except Exception as e:
+        logger.error(f"Failed to load models: {e}")
+        raise
+
+    interface = gr.Interface(
+        fn=predict,
+        inputs=[gr.File(label="Upload CSV file", file_types=[".csv"], type="filepath")],
+        outputs=[gr.JSON(label="Predictions")],
+        title="Photometric Redshift Prediction",
+        description="Upload a CSV file containing flux measurements to get redshift predictions.",
     )
+
+    interface.launch(show_error=True, debug=True)
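A quick way to exercise the new loading path outside Spaces is to import the module and wire up LOADED_MODELS by hand before calling predict(). The sketch below assumes the file is importable as app, that modelF_DA.pt and modelZ_DA.pt exist under the resolved model directory, and that example_fluxes.csv is a small CSV of flux columns; none of these names come from the commit itself.

# Hypothetical local smoke test (module name, weight files, and CSV path are assumptions)
import app

app.LOADED_MODELS = app.load_models(app.get_model_path())  # eager load, as in __main__
print(app.predict("example_fluxes.csv"))                   # returns (z.tolist(),) on success

Setting the SPACE_ID environment variable before running, as Hugging Face Spaces does, switches get_model_path() to the Path("data/models") branch instead of PROJ_ROOT / "data/models/".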