# fastapi_test / app.py
# André Fernandes
# added predict method for a classic ML classifier model (wip)
# 48b647e
from contextlib import asynccontextmanager
from typing import Optional, Any

from fastapi import FastAPI, HTTPException
from joblib import load
from pydantic import BaseModel, PrivateAttr

from models.iris import Iris
class Model(BaseModel):
    """Metadata record for one model hosted by this API.

    ``id``, ``name`` and ``param_count`` are regular pydantic fields.
    ``_model`` holds the loaded estimator object (if any) as a pydantic
    private attribute, so it can be assigned after construction without
    being treated as a field.
    """

    id: int
    name: str
    param_count: Optional[int] = None
    # Declared explicitly via PrivateAttr: a bare underscore-prefixed
    # annotation only becomes a private attribute implicitly on pydantic v2;
    # on v1 assigning ``instance._model = ...`` would raise. The explicit
    # form behaves the same on both.
    _model: Optional[Any] = PrivateAttr(default=None)
# Catalogue of hosted models, keyed by the string form of each model id.
models = {
    str(model_id): Model(id=model_id, name=label)
    for model_id, label in enumerate(("CNN", "Transformer", "Iris"))
}

# Integer id -> model record.
id_2_hosted_models = {entry.id: entry for entry in models.values()}

# Lower-cased model name -> integer id (for case-insensitive name lookup).
model_names_2_id = {entry.name.lower(): entry.id for entry in models.values()}

# TODO: fix this mess ^^ (several parallel views over the same records)
# Model name exactly as declared (case preserved) -> model record.
ml_models = {entry.name: entry for entry in models.values()}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the Iris estimator on startup and release it on shutdown.

    Intended to be passed as the ``lifespan`` argument of ``FastAPI``: the
    code before ``yield`` runs before the application starts serving, the
    code after it runs when the application shuts down.

    Args:
        app: The application instance (not used by this handler).
    """
    iris_entry = ml_models["Iris"]
    # Relative path: resolved against the process working directory.
    iris_entry._model = load('models/iris_v1.joblib')
    try:
        yield
    finally:
        # Release only the loaded estimator, and do it even if shutdown is
        # triggered by an error. Clearing ``ml_models`` itself would delete
        # the registry entries created at import time, so a second startup
        # in the same process would fail with KeyError on ml_models["Iris"].
        iris_entry._model = None
################################################################
# Application object; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="ML Repository API",
    version="0.0.1",
    description="API for getting predictions from hosted ML models.",
)
@app.get("/")
def greet_json():
    # Root endpoint: responds with a fixed welcome payload.
    welcome = {"Hello World": "Welcome to my ML Repository API!"}
    return welcome
@app.get("/hosted")
def list_models():
    """List all the hosted models."""
    # The whole catalogue is returned, keyed by the string form of each id.
    catalogue = models
    return catalogue
@app.get("/hosted/id/{model_id}")
def get_by_id(model_id: int):
    """Get a specific model by its ID."""
    try:
        return id_2_hosted_models[model_id]
    except KeyError:
        # Unknown id: answer with 404 instead of letting the KeyError escape.
        raise HTTPException(
            status_code=404,
            detail=f"Model with 'id={model_id}' not found",
        ) from None
@app.get("/hosted/name/{model_name}")
def get_by_name(model_name: str):
    """Get a specific model by its name."""
    # Names are matched case-insensitively; the lower-cased form is also the
    # one echoed back in the error message.
    model_name = model_name.lower()
    found_id = model_names_2_id.get(model_name)
    if found_id is None:
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
    return id_2_hosted_models[found_id]
@app.post("/hosted/name/{model_name}/predict", tags=["Predictions"])
async def get_prediction(model_name: str, iris: Iris):
    """Run the hosted classifier on the submitted samples.

    Only the Iris model can serve predictions for now.

    Args:
        model_name: Name of the hosted model (matched case-insensitively).
        iris: Request body; its ``data`` field is passed to the estimator.

    Returns:
        A dict with ``predictions`` and ``log_probs``, both converted to
        plain lists.

    Raises:
        HTTPException: 404 if no model with that name is registered, 501 if
            the model is registered but prediction is not implemented for
            it, 503 if the Iris estimator has not been loaded.
    """
    key = model_name.lower()
    # Unknown names get the same 404 as the name-lookup endpoint; 501 is
    # reserved for models that are registered but cannot predict yet.
    if key not in model_names_2_id:
        raise HTTPException(status_code=404, detail=f"Model '{key}' not found")
    if key != "iris":
        raise HTTPException(status_code=501, detail="Not implemented yet.")

    # Look the estimator up once and fail with a clear 503 when it has not
    # been loaded, instead of an unhandled KeyError/AttributeError.
    entry = ml_models.get("Iris")
    estimator = getattr(entry, "_model", None)
    if estimator is None:
        raise HTTPException(status_code=503, detail="Model 'iris' is not loaded.")

    data = iris.data
    predictions = estimator.predict(data).tolist()
    # NOTE(review): these values come from predict_proba, so they look like
    # plain probabilities rather than log-probabilities. The "log_probs" key
    # is kept unchanged so existing clients keep working.
    probabilities = estimator.predict_proba(data).tolist()
    return {"predictions": predictions,
            "log_probs": probabilities}