Spaces:
Runtime error
Runtime error
File size: 3,301 Bytes
d55c17f 6f437f1 d55c17f 0143a6f 14f20e4 d55c17f 668e440 d55c17f 668e440 d55c17f 668e440 d55c17f 668e440 d55c17f 668e440 d55c17f 6f437f1 4e1e7f9 d55c17f 6f437f1 668e440 6f437f1 668e440 6f437f1 d55c17f 14f20e4 4e1e7f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
from __future__ import annotations
import argparse
import logging
from pathlib import Path
import gradio as gr
import pandas as pd
import torch
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule
logger = logging.getLogger(__name__)
# Prediction callback invoked by the Gradio interface for each upload.
def predict(input_file_path: Path):
    """Predict photometric redshifts for the fluxes in an uploaded CSV.

    Parameters
    ----------
    input_file_path : Path
        Path to a CSV file whose columns are flux measurements
        (one band per column, one source per row — assumed adjacent
        bands in column order; TODO confirm against the data spec).

    Returns
    -------
    dict | str
        On success, a dict with keys "redshift (z)", "posterior (pz)"
        and "odds"; on any failure, a human-readable error string.
    """
    model_path = Path("models/")

    logger.info("Loading data and converting fluxes to colors")
    try:
        fluxes = pd.read_csv(input_file_path, sep=",", header=0)
    except Exception as e:
        logger.error(f"Error loading input file: {e}")
        return f"Error loading file: {e}"

    # Colors are ratios of fluxes in adjacent bands. Work on the raw
    # ndarray: dividing two DataFrame slices aligns on column *labels*
    # (not position), which would make shared columns divide by
    # themselves and turn the edge columns into NaN.
    flux_values = fluxes.to_numpy()
    colors = flux_values[:, :-1] / flux_values[:, 1:]

    logger.info("Loading model...")
    # Instantiate the two networks and load their CPU state dicts.
    nn_features = EncoderPhotometry()
    nn_z = MeasureZ(num_gauss=6)
    try:
        nn_features.load_state_dict(
            torch.load(model_path / "modelF.pt", map_location=torch.device("cpu"))
        )
        nn_z.load_state_dict(
            torch.load(model_path / "modelZ.pt", map_location=torch.device("cpu"))
        )
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return f"Error loading model: {e}"

    # Inference only: switch off dropout/batch-norm training behaviour.
    nn_features.eval()
    nn_z.eval()

    temps_module = TempsModule(nn_features, nn_z)

    # Run the prediction pipeline.
    try:
        z, pz, odds = temps_module.get_pz(
            input_data=torch.Tensor(colors), return_pz=True, return_flag=True
        )
    except Exception as e:
        logger.error(f"Error during prediction: {e}")
        return f"Error during prediction: {e}"

    # Plain lists so the result is JSON-serializable for gr.JSON.
    return {
        "redshift (z)": z.tolist(),
        "posterior (pz)": pz.tolist(),
        "odds": odds.tolist(),
    }
def get_args() -> argparse.Namespace:
    """Parse command-line options for running the Gradio server.

    Returns:
        argparse.Namespace with ``log_level``, ``server_address``,
        ``input_file_path`` and ``port`` attributes.
    """
    arg_parser = argparse.ArgumentParser()
    log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"]
    arg_parser.add_argument("--log-level", default="INFO", choices=log_levels)
    # Address the HTTP server binds to (loopback by default).
    arg_parser.add_argument("--server-address", default="127.0.0.1", type=str)
    arg_parser.add_argument(
        "--input-file-path", type=Path, help="Path to the input CSV file"
    )
    arg_parser.add_argument("--port", type=int, default=7860)
    return arg_parser.parse_args()
# Gradio UI definition: one CSV upload in, one JSON panel of predictions out.
_csv_input = gr.File(label="Upload CSV file", file_types=[".csv"], type="filepath")
_json_output = gr.JSON(label="Predictions")
interface = gr.Interface(
    fn=predict,
    inputs=[_csv_input],
    outputs=[_json_output],
    title="Photometric Redshift Prediction",
    description="Upload a CSV file containing flux measurements to get redshift predictions, posterior probabilities, and odds.",
)
if __name__ == "__main__":
    # Script entry point: configure logging, then serve the Gradio app.
    cli_args = get_args()
    logging.basicConfig(level=cli_args.log_level)
    logger.info(f"Starting server on {cli_args.server_address}:{cli_args.port}")
    interface.launch(
        server_name=cli_args.server_address,
        server_port=cli_args.port,
        share=True,
        debug=True,  # verbose Gradio debug output
        show_error=True,  # surface detailed errors in the UI
    )
|