Spaces:
Runtime error
Runtime error
"""Server that will listen for GET requests from the client.""" | |
from fastapi import FastAPI | |
from joblib import load | |
from concrete.ml.deployment import FHEModelServer | |
from pydantic import BaseModel | |
import base64 | |
from pathlib import Path | |
# Anchor on-disk lookups to this file's directory rather than the working
# directory the process happened to be started from.
current_dir = Path(__file__).parent

# The FastAPI application object (the commented-out uvicorn.run at the bottom
# of this file shows it being served directly).
app = FastAPI()
def root():
    """Build the greeting payload for the health-prediction API root.

    NOTE(review): another function named ``root`` is defined later in this
    module and rebinds the name, so this definition is shadowed and unused.
    Consider deleting one of the two.

    Returns:
        dict: A one-key mapping whose ``"message"`` entry is the welcome text.
    """
    greeting = "Welcome to your disease prediction with FHE!"
    return dict(message=greeting)
# Diagnostics for locating the serialized model; imported here (not mid-file
# ad hoc) and routed through logging instead of bare print() calls.
import logging
from glob import glob

logger = logging.getLogger(__name__)

# Serialized FHE model artifacts are expected in an "fhe_model" directory next
# to this file, independent of the process's working directory.
fhe_model_dir = current_dir / "fhe_model"

# Fail fast with an explicit message (including what actually is on disk)
# instead of relying on whatever error the model loader would raise.
if not fhe_model_dir.is_dir():
    raise FileNotFoundError(
        f"FHE model directory not found: {fhe_model_dir} "
        f"(contents of {current_dir}: {sorted(glob(f'{current_dir}/*'))})"
    )

logger.debug("FHE model directory: %s", fhe_model_dir)
if logger.isEnabledFor(logging.DEBUG):
    # Only pay for the directory listing when debug output is enabled.
    logger.debug("FHE model files: %s", sorted(glob(f"{fhe_model_dir}/*")))

# Load the model
fhe_model = FHEModelServer(fhe_model_dir)
logger.debug("Loaded FHE model server: %r", fhe_model)
class PredictRequest(BaseModel):
    """Request body for an encrypted prediction.

    Both fields carry binary payloads serialized as base64 text; the
    prediction handler decodes each with ``base64.b64decode`` before use.
    """

    # Base64 text of the evaluation key passed to the model server's run().
    evaluation_key: str
    # Base64 text of the encrypted input passed to the model server's run().
    encrypted_encoding: str
# Define the default route (registered on the app so "/" actually exists).
@app.get("/")
def root():
    """Serve the API's default route with a welcome message.

    Returns:
        dict: ``{"message": "Welcome to Your ClairVault!"}``.
    """
    return {"message": "Welcome to Your ClairVault!"}
# NOTE(review): the route path is assumed; keep it in sync with the client.
@app.post("/predict")
def predict(query: PredictRequest):
    """Run the FHE model on an encrypted request and return the encrypted result.

    Args:
        query: Request body whose ``evaluation_key`` and ``encrypted_encoding``
            fields are base64-encoded binary payloads.

    Returns:
        dict: ``{"encrypted_prediction": <base64 str>}`` holding the base64
        encoding of the bytes returned by ``fhe_model.run``.

    Raises:
        HTTPException: 400 when either field is not decodable base64.
    """
    # Local import keeps this handler self-contained; fastapi is already a
    # dependency of this module.
    from fastapi import HTTPException

    try:
        encrypted_encoding = base64.b64decode(query.encrypted_encoding)
        evaluation_key = base64.b64decode(query.evaluation_key)
    except ValueError as err:
        # binascii.Error (bad padding) and non-ASCII input both subclass
        # ValueError; report them as client errors instead of a 500.
        raise HTTPException(
            status_code=400,
            detail="evaluation_key and encrypted_encoding must be base64-encoded",
        ) from err

    # Evaluate the model on the encrypted input using the client's key.
    prediction = fhe_model.run(encrypted_encoding, evaluation_key)

    # Base64 so the binary result can be returned inside a JSON response.
    encoded_prediction = base64.b64encode(prediction).decode()
    return {"encrypted_prediction": encoded_prediction}
# if __name__ == "__main__": | |
# uvicorn.run(app, host="0.0.0.0", port=3000) | |