# TEMPS / app.py
# Laura Cabayol Garcia
# HF API
# 6f437f1
# raw
# history blame
# 3.24 kB
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
import gradio as gr
import pandas as pd
import torch
from huggingface_hub import snapshot_download
from temps.archive import Archive
from temps.temps_arch import EncoderPhotometry, MeasureZ
from temps.temps import TempsModule
logger = logging.getLogger(__name__)
# Define the prediction function that will be called by Gradio
def predict(input_file_path: Path):
    """Estimate photometric redshifts for the fluxes stored in a CSV file.

    The file is read with pandas (comma-separated, first row used as header),
    every pair of adjacent columns is turned into a flux ratio ("color"), and
    the ratios are handed to ``TempsModule.get_pz``.

    Args:
        input_file_path: Location of the CSV file holding the flux columns.

    Returns:
        A dict with the keys ``"redshift (z)"``, ``"posterior (pz)"`` and
        ``"odds"`` (each value converted with ``.tolist()``), or a string
        describing the failure when reading the file, loading the model
        weights, or running the prediction raises.
    """
    model_path = Path("app/models/")

    logger.info("Loading data and converting fluxes to colors")

    # Load the input data file (CSV)
    try:
        fluxes = pd.read_csv(input_file_path, sep=',', header=0)
    except Exception as e:
        logger.exception("Error loading input file: %s", e)
        return f"Error loading file: {e}"

    # Colors are ratios of adjacent columns. Divide the underlying NumPy
    # arrays: dividing the two DataFrame slices directly makes pandas align
    # them on column labels, which yields NaN for the first/last column and
    # 1.0 for every other one instead of the intended neighbour ratios.
    flux_values = fluxes.to_numpy()
    colors = flux_values[:, :-1] / flux_values[:, 1:]

    logger.info("Loading model...")

    # Load the neural network models from the given model path
    nn_features = EncoderPhotometry()
    nn_z = MeasureZ(num_gauss=6)

    try:
        # map_location remaps the stored tensors onto the CPU.
        cpu = torch.device('cpu')
        nn_features.load_state_dict(torch.load(model_path / 'modelF.pt', map_location=cpu))
        nn_z.load_state_dict(torch.load(model_path / 'modelZ.pt', map_location=cpu))
    except Exception as e:
        logger.exception("Error loading model: %s", e)
        return f"Error loading model: {e}"

    temps_module = TempsModule(nn_features, nn_z)

    # Run predictions
    try:
        z, pz, odds = temps_module.get_pz(
            input_data=torch.Tensor(colors),
            return_pz=True,
            return_flag=True,
        )
    except Exception as e:
        logger.exception("Error during prediction: %s", e)
        return f"Error during prediction: {e}"

    # Return the predictions as a dictionary
    return {
        "redshift (z)": z.tolist(),
        "posterior (pz)": pz.tolist(),
        "odds": odds.tolist(),
    }
def get_args() -> argparse.Namespace:
    """Parse the command-line options of the app.

    Recognised options:
        --log-level        one of the standard logging level names
                           (default "INFO")
        --server-address   address to bind to (default "0.0.0.0")
        --input-file-path  path to the input CSV file (no default)
        --port             port number (default 7860)

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    level_names = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"]

    cli = argparse.ArgumentParser()
    cli.add_argument("--log-level", choices=level_names, default="INFO")
    cli.add_argument("--server-address", type=str, default="0.0.0.0")
    cli.add_argument("--input-file-path", type=Path, help="Path to the input CSV file")
    cli.add_argument("--port", type=int, default=7860)

    namespace = cli.parse_args()
    return namespace
# Gradio UI: one CSV upload (passed to `predict` as a file path) and a JSON
# panel that displays whatever `predict` returns.
interface = gr.Interface(
    title="Photometric Redshift Prediction",
    description=(
        "Upload a CSV file containing flux measurements to get redshift "
        "predictions, posterior probabilities, and odds."
    ),
    fn=predict,
    inputs=[
        gr.File(type="filepath", file_types=[".csv"], label="Upload CSV file"),
    ],
    outputs=[
        gr.JSON(label="Predictions"),
    ],
)
if __name__ == "__main__":
    # Honour the CLI options declared in get_args(); with no arguments the
    # server still binds to 0.0.0.0:7860.
    args = get_args()

    # Without a configured handler the INFO messages emitted through the
    # module logger are never shown.
    logging.basicConfig(level=args.log_level)

    interface.launch(
        server_name=args.server_address,
        server_port=args.port,
        share=True,
    )