File size: 2,591 Bytes
743fc77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

"""Deployment server.

Routes:
    - Get client.zip
    - Add a key
    - Compute
"""

import io
import os
import uuid
from pathlib import Path
from typing import Dict

import uvicorn
from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, StreamingResponse

# No relative import here because when not used in the package itself
from concrete.ml.deployment import FHEModelServer

if __name__ == "__main__":
    app = FastAPI(debug=False)

    FILE_FOLDER = Path(__file__).parent

    KEY_PATH = Path(os.environ.get("KEY_PATH", FILE_FOLDER / Path("server_keys")))
    CLIENT_SERVER_PATH = Path(os.environ.get("PATH_TO_MODEL", FILE_FOLDER / Path("dev")))
    PORT = os.environ.get("PORT", "5000")

    fhe = FHEModelServer(str(CLIENT_SERVER_PATH.resolve()))

    KEYS: Dict[str, bytes] = {}

    PATH_TO_CLIENT = (CLIENT_SERVER_PATH / "client.zip").resolve()
    PATH_TO_SERVER = (CLIENT_SERVER_PATH / "server.zip").resolve()

    assert PATH_TO_CLIENT.exists()
    assert PATH_TO_SERVER.exists()

    @app.get("/get_client")
    def get_client():
        """Get client.

        Returns:
            FileResponse: client.zip

        Raises:
            HTTPException: if the file can't be find locally
        """
        path_to_client = (CLIENT_SERVER_PATH / "client.zip").resolve()
        if not path_to_client.exists():
            raise HTTPException(status_code=500, detail="Could not find client.")
        return FileResponse(path_to_client, media_type="application/zip")

    @app.post("/add_key")
    async def add_key(key: UploadFile):
        """Add public key.

        Arguments:
            key (UploadFile): public key

        Returns:
            Dict[str, str]
                - uid: uid a personal uid
        """
        uid = str(uuid.uuid4())
        KEYS[uid] = await key.read()
        return {"uid": uid}

    @app.post("/compute")
    async def compute(model_input: UploadFile, uid: str = Form()):  # noqa: B008
        """Compute the circuit over encrypted input.

        Arguments:
            model_input (UploadFile): input of the circuit
            uid (str): uid of the public key to use

        Returns:
            StreamingResponse: the result of the circuit
        """
        key = KEYS[uid]
        encrypted_results = fhe.run(
            serialized_encrypted_quantized_data=await model_input.read(),
            serialized_evaluation_keys=key,
        )
        return StreamingResponse(
            io.BytesIO(encrypted_results),
        )

    uvicorn.run(app, host="0.0.0.0", port=int(PORT))