lauracabayol commited on
Commit
d55c17f
·
1 Parent(s): 9424894
Files changed (1) hide show
  1. temps/app.py +127 -0
temps/app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import argparse
3
+ import logging
4
+ import sys
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ import gradio as gr
9
+ import pandas as pd
10
+ import torch
11
+ from huggingface_hub import snapshot_download
12
+
13
+ from temps.archive import Archive
14
+ from temps.temps_arch import EncoderPhotometry, MeasureZ
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Define the prediction function that will be called by Gradio
19
+ def predict(input_file_path: Path, model_path: Path):
20
+ logging.basicConfig(
21
+ stream=sys.stdout,
22
+ level=logging.INFO,
23
+ format="%(levelname)s:%(message)s",
24
+ )
25
+
26
+ logger.info("Loading data and converting fluxes to colors")
27
+
28
+ # Load the input data file (CSV)
29
+ try:
30
+ fluxes = pd.read_csv(input_file_path, sep=',', header=0)
31
+ except Exception as e:
32
+ logger.error(f"Error loading input file: {e}")
33
+ return f"Error loading file: {e}"
34
+
35
+ # Assuming that the model expects "colors" as input
36
+ colors = fluxes.iloc[:, :-1] / fluxes.iloc[:, 1:]
37
+
38
+ logger.info("Loading model...")
39
+
40
+ # Load the neural network models from the given model path
41
+ nn_features = EncoderPhotometry()
42
+ nn_z = MeasureZ(num_gauss=6)
43
+
44
+ try:
45
+ nn_features.load_state_dict(torch.load(model_path / 'modelF.pt', map_location=torch.device('cpu')))
46
+ nn_z.load_state_dict(torch.load(model_path / 'modelZ.pt', map_location=torch.device('cpu')))
47
+ except Exception as e:
48
+ logger.error(f"Error loading model: {e}")
49
+ return f"Error loading model: {e}"
50
+
51
+ temps_module = TempsModule(nn_features, nn_z)
52
+
53
+ # Run predictions
54
+ try:
55
+ z, pz, odds = temps_module.get_pz(input_data=torch.Tensor(colors.values),
56
+ return_pz=True,
57
+ return_flag=True)
58
+ except Exception as e:
59
+ logger.error(f"Error during prediction: {e}")
60
+ return f"Error during prediction: {e}"
61
+
62
+ # Return the predictions as a dictionary
63
+ result = {
64
+ "redshift (z)": z.tolist(),
65
+ "posterior (pz)": pz.tolist(),
66
+ "odds": odds.tolist()
67
+ }
68
+ return result
69
+
70
+
71
+ # Gradio app
72
+ def main(args: Optional[argparse.Namespace] = None) -> None:
73
+ if args is None:
74
+ args = get_args()
75
+
76
+ # Define the Gradio interface
77
+ gr.Interface(
78
+ fn=predict, # the function that Gradio will call
79
+ inputs=[
80
+ gr.inputs.File(label="Upload your input CSV file"), # file input for the data
81
+ gr.inputs.Textbox(label="Model path", default=str(args.model_path)), # text input for model path
82
+ ],
83
+ outputs="json", # return the results as JSON
84
+ live=False,
85
+ title="Prediction App",
86
+ description="Upload a CSV file with your data to get predictions.",
87
+ ).launch(server_name=args.server_name, server_port=args.port)
88
+
89
+
90
+ def get_args() -> argparse.Namespace:
91
+ parser = argparse.ArgumentParser()
92
+
93
+ parser.add_argument(
94
+ "--log-level",
95
+ default="INFO",
96
+ choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"],
97
+ )
98
+
99
+ parser.add_argument(
100
+ "--server-name",
101
+ default="127.0.0.1",
102
+ type=str,
103
+ )
104
+
105
+ parser.add_argument(
106
+ "--input-file-path",
107
+ type=Path,
108
+ help="Path to the input CSV file",
109
+ )
110
+
111
+ parser.add_argument(
112
+ "--model-path",
113
+ type=Path,
114
+ help="Path to the model files",
115
+ )
116
+
117
+ parser.add_argument(
118
+ "--port",
119
+ type=int,
120
+ default=7860,
121
+ )
122
+
123
+ return parser.parse_args()
124
+
125
+
126
+ if __name__ == "__main__":
127
+ main()