from __future__ import annotations

import argparse
import logging
from pathlib import Path

import gradio as gr
import pandas as pd
import torch

from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule

logger = logging.getLogger(__name__)
# Define the prediction function that will be called by Gradio
def predict(input_file_path: Path):
    model_path = Path("models/")

    logger.info("Loading data and converting fluxes to colors")

    # Load the input data file (CSV)
    try:
        fluxes = pd.read_csv(input_file_path, sep=",", header=0)
    except Exception as e:
        logger.error(f"Error loading input file: {e}")
        return f"Error loading file: {e}"
    # The model expects "colors" as input: ratios of fluxes in adjacent bands.
    # Use .values so pandas does not align the two slices by column label,
    # which would fill the result with 1s and NaNs instead of flux ratios.
    colors = fluxes.iloc[:, :-1].values / fluxes.iloc[:, 1:].values
logger.info("Loading model...") | |
# Load the neural network models from the given model path | |
nn_features = EncoderPhotometry() | |
nn_z = MeasureZ(num_gauss=6) | |
try: | |
nn_features.load_state_dict( | |
torch.load(model_path / "modelF.pt", map_location=torch.device("cpu")) | |
) | |
nn_z.load_state_dict( | |
torch.load(model_path / "modelZ.pt", map_location=torch.device("cpu")) | |
) | |
except Exception as e: | |
logger.error(f"Error loading model: {e}") | |
return f"Error loading model: {e}" | |
    temps_module = TempsModule(nn_features, nn_z)

    # Run predictions
    try:
        z, pz, odds = temps_module.get_pz(
            input_data=torch.Tensor(colors), return_pz=True, return_flag=True
        )
    except Exception as e:
        logger.error(f"Error during prediction: {e}")
        return f"Error during prediction: {e}"

    # Return the predictions as a dictionary
    result = {
        "redshift (z)": z.tolist(),
        "posterior (pz)": pz.tolist(),
        "odds": odds.tolist(),
    }
    return result
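# Quick local sanity check, independent of the Gradio UI ("example_fluxes.csv"
# is a hypothetical file with one flux column per band):
#   print(predict(Path("example_fluxes.csv")))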
def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"],
    )
    parser.add_argument(
        "--server-address",
        default="127.0.0.1",
        type=str,
    )
    parser.add_argument(
        "--input-file-path",
        type=Path,
        help="Path to the input CSV file",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
    )
    return parser.parse_args()
interface = gr.Interface(
    fn=predict,
    inputs=[gr.File(label="Upload CSV file", file_types=[".csv"], type="filepath")],
    outputs=[gr.JSON(label="Predictions")],
    title="Photometric Redshift Prediction",
    description=(
        "Upload a CSV file containing flux measurements to get redshift "
        "predictions, posterior probabilities, and odds."
    ),
)
if __name__ == "__main__":
    args = get_args()
    logging.basicConfig(level=args.log_level)

    logger.info(f"Starting server on {args.server_address}:{args.port}")
    interface.launch(
        server_name=args.server_address,
        server_port=args.port,
        share=True,
        debug=True,  # run Gradio in debug mode
        show_error=True,  # surface detailed error messages in the UI
    )
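# To run the app outside of Spaces (assuming this file is saved as app.py):
#   python app.py --server-address 0.0.0.0 --port 7860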