# NOTE(review): "Spaces: / Runtime error / Runtime error" was Hugging Face
# Spaces status-page text pasted above the code, not Python — commented out
# so the file parses.
from __future__ import annotations  # must remain the first statement in the file

import argparse
import logging
from pathlib import Path

import gradio as gr
import pandas as pd
import torch
import os

from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule
from temps.constants import PROJ_ROOT

logger = logging.getLogger(__name__)
def get_model_path() -> Path:
    """Resolve the directory that holds the model weight files.

    On Hugging Face Spaces (detected via the ``SPACE_ID`` environment
    variable) the weights are expected under a repo-relative
    ``data/models``; in a local checkout they live under ``PROJ_ROOT``.

    Returns
    -------
    Path
        Directory expected to contain the ``*.pt`` weight files.
    """
    running_on_spaces = os.environ.get("SPACE_ID")
    if not running_on_spaces:
        # Local / development checkout.
        return PROJ_ROOT / "data/models/"
    logger.info("Running on HuggingFace Spaces")
    # NOTE(review): this path is relative (resolved against the Space's
    # working directory), not absolute as the original comment claimed.
    return Path("data/models")
def load_models(model_path: Path):
    """Instantiate the two networks and restore their weights on CPU.

    Parameters
    ----------
    model_path : Path
        Directory containing ``modelF_DA.pt`` and ``modelZ_DA.pt``.

    Returns
    -------
    tuple
        ``(feature_encoder, redshift_head)`` with state dicts loaded.
    """
    logger.info(f"Loading models from {model_path}")

    def _restore(module, filename):
        # Load a CPU state dict (weights_only guards against pickled code)
        # and apply it to the freshly constructed module.
        state = torch.load(
            model_path / filename,
            weights_only=True,
            map_location=torch.device("cpu"),
        )
        module.load_state_dict(state)
        return module

    feature_encoder = _restore(EncoderPhotometry(), "modelF_DA.pt")
    redshift_head = _restore(MeasureZ(num_gauss=6), "modelZ_DA.pt")
    return feature_encoder, redshift_head
def predict(input_file_path: Path):
    """Run the photometric-redshift pipeline on an uploaded CSV of fluxes.

    Parameters
    ----------
    input_file_path : Path
        Comma-separated file with a header row; columns are flux
        measurements. Colors are formed as ratios of adjacent columns.

    Returns
    -------
    tuple | str
        A 1-tuple containing the list of predicted redshifts, or an
        error message string on failure.
    """
    global LOADED_MODELS
    # BUG FIX: if the module is imported without the __main__ section ever
    # running, LOADED_MODELS is unbound and the original `is None` test
    # raised NameError instead of returning the error message.
    if globals().get("LOADED_MODELS") is None:
        logger.error("Models not loaded!")
        return "Error: Models not initialized"
    nn_features, nn_z = LOADED_MODELS

    try:
        fluxes = pd.read_csv(input_file_path, sep=",", header=0)
    except Exception as e:
        logger.error("Error loading input file: %s", e)
        return f"Error loading file: {e}"

    # Colors: ratio of each flux column to its right-hand neighbour.
    colors = fluxes.values[:, :-1] / fluxes.values[:, 1:]

    temps_module = TempsModule(nn_features, nn_z)
    try:
        z, pz, odds = temps_module.get_pz(
            input_data=torch.Tensor(colors), return_pz=True, return_flag=True
        )
    except Exception as e:
        logger.error("Error during prediction: %s", e)
        return f"Error during prediction: {e}"

    # NOTE(review): pz and odds are computed but never returned, although
    # the UI description promises them — confirm whether they should be
    # included in the JSON output.
    return (z.tolist(),)
def get_args() -> argparse.Namespace:
    """Parse the command-line options for the Gradio app.

    Returns
    -------
    argparse.Namespace
        ``log_level`` — logging verbosity name (default ``"INFO"``);
        ``server_address`` — host interface to bind (default ``"127.0.0.1"``);
        ``input_file_path`` — optional CSV path;
        ``port`` — TCP port to serve on (default ``7860``).
    """
    parser = argparse.ArgumentParser()
    level_names = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"]
    parser.add_argument("--log-level", default="INFO", choices=level_names)
    parser.add_argument("--server-address", default="127.0.0.1", type=str)
    parser.add_argument(
        "--input-file-path", type=Path, help="Path to the input CSV file"
    )
    parser.add_argument("--port", type=int, default=7860)
    return parser.parse_args()
# Gradio UI: one CSV upload in, one JSON payload of predictions out.
_csv_input = gr.File(label="Upload CSV file", file_types=[".csv"], type="filepath")
_json_output = gr.JSON(label="Predictions")
interface = gr.Interface(
    fn=predict,
    inputs=[_csv_input],
    outputs=[_json_output],
    title="Photometric Redshift Prediction",
    description="Upload a CSV file containing flux measurements to get redshift predictions, posterior probabilities, and odds.",
)
if __name__ == "__main__":
    args = get_args()
    logging.basicConfig(level=args.log_level)

    # Load the models once, before serving any requests; predict() reads
    # the LOADED_MODELS global set here.
    try:
        model_path = get_model_path()
        logger.info("Loading models...")
        LOADED_MODELS = load_models(model_path)
        logger.info("Models loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load models: {e}")
        raise

    # BUG FIX: argparse defines --server-address and --port, so the
    # Namespace attributes are `server_address` and `port`. The original
    # `args.server_name` / `args.server_port` raised AttributeError at
    # startup (the "Runtime error" reported by the Space).
    interface.launch(server_name=args.server_address, server_port=args.port)