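"""Gradio app for photometric redshift prediction with the TEMPS models.

Users upload a CSV file of fluxes; the app converts the fluxes to colors,
runs the pre-trained encoder and redshift networks, and returns redshifts,
posteriors, and odds as JSON.
"""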
from __future__ import annotations

import argparse
import logging
from pathlib import Path

import gradio as gr
import pandas as pd
import torch

from temps.temps_arch import EncoderPhotometry, MeasureZ
# TempsModule is assumed to live in temps.temps; adjust if your package layout differs.
from temps.temps import TempsModule

logger = logging.getLogger(__name__)
# Prediction function that will be called by Gradio
def predict(input_file_path: Path):
    model_path = Path("app/models/")

    logger.info("Loading data and converting fluxes to colors")

    # Load the input data file (CSV)
    try:
        fluxes = pd.read_csv(input_file_path, sep=",", header=0)
    except Exception as e:
        logger.error(f"Error loading input file: {e}")
        return f"Error loading file: {e}"

    # The model expects "colors" as input: ratios of fluxes in adjacent bands.
    # Divide as NumPy arrays so pandas does not align (and NaN out) the
    # mismatched column labels.
    flux_values = fluxes.to_numpy()
    colors = flux_values[:, :-1] / flux_values[:, 1:]
logger.info("Loading model...") | |
# Load the neural network models from the given model path | |
nn_features = EncoderPhotometry() | |
nn_z = MeasureZ(num_gauss=6) | |
try: | |
nn_features.load_state_dict(torch.load(model_path / 'modelF.pt', map_location=torch.device('cpu'))) | |
nn_z.load_state_dict(torch.load(model_path / 'modelZ.pt', map_location=torch.device('cpu'))) | |
except Exception as e: | |
logger.error(f"Error loading model: {e}") | |
return f"Error loading model: {e}" | |
    temps_module = TempsModule(nn_features, nn_z)

    # Run predictions
    try:
        z, pz, odds = temps_module.get_pz(
            input_data=torch.Tensor(colors),
            return_pz=True,
            return_flag=True,
        )
    except Exception as e:
        logger.error(f"Error during prediction: {e}")
        return f"Error during prediction: {e}"

    # Return the predictions as a dictionary
    result = {
        "redshift (z)": z.tolist(),
        "posterior (pz)": pz.tolist(),
        "odds": odds.tolist(),
    }
    return result
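# Sketch of the expected CSV layout (hypothetical column names and values): one
# flux column per band, ordered so that adjacent columns form the color ratios
# computed above, e.g.
#   flux_g,flux_r,flux_i,flux_z
#   12.3,10.1,9.8,9.5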
# Gradio app
def main(args=None) -> None:
    if args is None:
        args = get_args()

    # Apply the log level requested on the command line.
    logging.basicConfig(level=args.log_level)

    # Define the Gradio interface
    gr.Interface(
        fn=predict,  # the function that Gradio will call
        inputs=[
            gr.File(label="Upload your input CSV file"),  # file input for the data
        ],
        outputs="json",  # return the results as JSON
        live=False,
        title="Prediction App",
        description="Upload a CSV file with your data to get predictions.",
    ).launch(server_name=args.server_name, server_port=args.port, share=True)
def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"],
    )
    parser.add_argument(
        "--server-name",
        default="127.0.0.1",
        type=str,
    )
    parser.add_argument(
        "--input-file-path",
        type=Path,
        help="Path to the input CSV file",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
    )
    return parser.parse_args()


if __name__ == "__main__":
    main()
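# Example local run (assuming this file is saved as app.py):
#   python app.py --server-name 0.0.0.0 --port 7860 --log-level DEBUG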