Spaces:
Sleeping
Sleeping
from typing import Optional, Any | |
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from contextlib import asynccontextmanager | |
from joblib import load | |
from models.iris import Iris | |
class Model(BaseModel):
    """Metadata record for one hosted ML model.

    The public fields are what the listing/lookup handlers return; the
    underscore-prefixed ``_model`` slot holds the loaded estimator object.
    """

    # Numeric identifier of the hosted model.
    id: int
    # Display name; name-based lookups compare it lower-cased.
    name: str
    # Optional parameter count; defaults to None when not provided.
    param_count: Optional[int] = None
    # Loaded estimator object, None until something assigns it. Being
    # underscore-prefixed, pydantic (v2) treats it as a private attribute that is
    # neither validated nor serialized into responses -- confirm against the
    # pinned pydantic version.
    _model : Optional[Any] = None
# Registry of hosted models keyed by the string form of their id ("0", "1", ...).
# The ids are the positions in this tuple of names.
models = {
    str(index): Model(id=index, name=label)
    for index, label in enumerate(("CNN", "Transformer", "Iris"))
}

# Derived lookup tables. They hold the very same Model instances as `models`, so
# state attached through one table (e.g. a loaded `_model`) is visible in all.
id_2_hosted_models = {entry.id: entry for entry in models.values()}
model_names_2_id = {entry.name.lower(): entry.id for entry in models.values()}
# TODO: collapse these overlapping lookup tables into a single registry.
ml_models = {entry.name: entry for entry in models.values()}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the hosted ML models at startup and release them at shutdown.

    Passed to FastAPI as the ``lifespan`` handler: code before ``yield`` runs
    on startup, code after it on shutdown.

    Args:
        app: The application instance (required by the lifespan signature,
            unused here).
    """
    # joblib artifacts are pickles: only ever load trusted, locally shipped files.
    ml_models["Iris"]._model = load('models/iris_v1.joblib')
    try:
        yield
    finally:
        # Release the loaded estimators. Merely clearing ``ml_models`` would not
        # free them: its values are the same Model instances still referenced by
        # ``models`` and ``id_2_hosted_models``. Keeping the entries also lets the
        # lifespan run again (e.g. a second test client) without a KeyError.
        for hosted in ml_models.values():
            hosted._model = None
# ---------------------------------------------------------------------------
# Application object
# ---------------------------------------------------------------------------
# title / description / version are API metadata; `lifespan` wires in the
# startup/shutdown model loading defined above.
app = FastAPI(
    lifespan=lifespan,
    version="0.0.1",
    title="ML Repository API",
    description="API for getting predictions from hosted ML models.",
)
def greet_json():
    """Return a static welcome message as a JSON-serialisable dict.

    NOTE(review): no route decorator is visible on this handler; confirm how it
    is registered with ``app``.
    """
    greeting = "Welcome to my ML Repository API!"
    return {"Hello World": greeting}
def list_models():
    """List all the hosted models.

    Returns the module-level ``models`` registry itself (not a copy), keyed by
    string model ids.
    """
    registry = models
    return registry
def get_by_id(model_id: int):
    """Get a specific model by its ID.

    Raises:
        HTTPException: 404 when no hosted model has ``model_id``.
    """
    # Registry values are Model instances, never None, so None means "absent".
    found = id_2_hosted_models.get(model_id)
    if found is None:
        raise HTTPException(status_code=404, detail=f"Model with 'id={model_id}' not found")
    return found
def get_by_name(model_name: str):
    """Get a specific model by its name (case-insensitive).

    Raises:
        HTTPException: 404 when no hosted model has that name; the detail
            message echoes the lower-cased name.
    """
    key = model_name.lower()
    # Ids are ints (never None), so a None result means the name is unknown.
    model_id = model_names_2_id.get(key)
    if model_id is None:
        raise HTTPException(status_code=404, detail=f"Model '{key}' not found")
    return id_2_hosted_models[model_id]
async def get_prediction(model_name: str, iris: Iris):
    """Return predictions from the named hosted model for the given input.

    Only the Iris model currently serves predictions.

    Args:
        model_name: Hosted model name, matched case-insensitively.
        iris: Request payload; its ``data`` field is passed unchanged to the
            estimator's ``predict`` / ``predict_proba``.

    Returns:
        dict with ``"predictions"`` (``predict`` output) and ``"log_probs"``
        (``predict_proba`` output), each converted with ``.tolist()``.

    Raises:
        HTTPException: 404 if no hosted model has that name, 501 if the model
            exists but does not serve predictions yet, 503 if the Iris
            estimator is not currently loaded.
    """
    key = model_name.lower()
    if key not in model_names_2_id:
        # Unknown names get the same 404 as get_by_name rather than a
        # misleading "not implemented" 501.
        raise HTTPException(status_code=404, detail=f"Model '{key}' not found")
    if key != "iris":
        raise HTTPException(status_code=501, detail="Not implemented yet.")

    entry = ml_models.get("Iris")
    estimator = entry._model if entry is not None else None
    if estimator is None:
        # Not loaded yet, or released at shutdown: report it explicitly instead
        # of crashing with AttributeError/KeyError (an opaque HTTP 500).
        raise HTTPException(status_code=503, detail="Model 'Iris' is not loaded.")

    data = dict(iris)['data']
    # NOTE(review): prediction runs synchronously inside an async handler and
    # blocks the event loop while it runs; fine for tiny inputs.
    prediction = estimator.predict(data).tolist()
    # NOTE(review): if this is a scikit-learn classifier (presumably), predict_proba
    # returns class probabilities, not log-probabilities, despite the key name.
    # The key is kept unchanged for client compatibility.
    log_probs = estimator.predict_proba(data).tolist()
    return {"predictions": prediction,
            "log_probs": log_probs}