Laura Cabayol Garcia commited on
Commit
6f437f1
·
1 Parent(s): 51b2703
Files changed (1) hide show
  1. app.py +24 -22
app.py CHANGED
@@ -11,6 +11,7 @@ from huggingface_hub import snapshot_download
11
 
12
  from temps.archive import Archive
13
  from temps.temps_arch import EncoderPhotometry, MeasureZ
 
14
 
15
  logger = logging.getLogger(__name__)
16
 
@@ -62,25 +63,6 @@ def predict(input_file_path: Path):
62
  }
63
  return result
64
 
65
-
66
- # Gradio app
67
- def main(args=None) -> None:
68
- if args is None:
69
- args = get_args()
70
-
71
- # Define the Gradio interface
72
- gr.Interface(
73
- fn=predict, # the function that Gradio will call
74
- inputs=[
75
- gr.File(label="Upload your input CSV file"), # file input for the data
76
- ],
77
- outputs="json", # return the results as JSON
78
- live=False,
79
- title="Prediction App",
80
- description="Upload a CSV file with your data to get predictions.",
81
- ).launch(server_name=args.server_name, server_port=args.port, share=True)
82
-
83
-
84
  def get_args() -> argparse.Namespace:
85
  parser = argparse.ArgumentParser()
86
 
@@ -91,8 +73,8 @@ def get_args() -> argparse.Namespace:
91
  )
92
 
93
  parser.add_argument(
94
- "--server-name",
95
- default="127.0.0.1",
96
  type=str,
97
  )
98
 
@@ -111,5 +93,25 @@ def get_args() -> argparse.Namespace:
111
  return parser.parse_args()
112
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  if __name__ == "__main__":
115
- main()
 
 
 
 
 
11
 
12
  from temps.archive import Archive
13
  from temps.temps_arch import EncoderPhotometry, MeasureZ
14
+ from temps.temps import TempsModule
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
63
  }
64
  return result
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def get_args() -> argparse.Namespace:
67
  parser = argparse.ArgumentParser()
68
 
 
73
  )
74
 
75
  parser.add_argument(
76
+ "--server-address", # Changed from server-name
77
+ default="0.0.0.0", # Changed default to match launch
78
  type=str,
79
  )
80
 
 
93
  return parser.parse_args()
94
 
95
 
96
+ interface = gr.Interface(
97
+ fn=predict,
98
+ inputs=[
99
+ gr.File(
100
+ label="Upload CSV file",
101
+ file_types=[".csv"],
102
+ type="filepath"
103
+ )
104
+ ],
105
+ outputs=[
106
+ gr.JSON(label="Predictions")
107
+ ],
108
+ title="Photometric Redshift Prediction",
109
+ description="Upload a CSV file containing flux measurements to get redshift predictions, posterior probabilities, and odds."
110
+ )
111
+
112
  if __name__ == "__main__":
113
+ interface.launch(
114
+ server_name="0.0.0.0",
115
+ server_port=7860,
116
+ share=True
117
+ )