File size: 3,364 Bytes
743fc77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117

"""Client script.

This script does the following:
    - Query crypto-parameters and pre/post-processing parameters (client.zip)
    - Generate the keys and upload the evaluation keys to the server
    - Quantize the inputs using the parameters
    - Encrypt data using the crypto-parameters
    - Send the encrypted data to the server (async using grequests)
    - Collect the data and decrypt it
    - De-quantize the decrypted results
"""

import io
import os
import sys
from pathlib import Path

import grequests
import numpy
import requests
import torch
import torchvision
import torchvision.transforms as transforms

from concrete.ml.deployment import FHEModelClient

# Server location: each setting may be overridden through an environment variable.
PORT = os.environ.get("PORT", "5000")
IP = os.environ.get("IP", "localhost")
# Full base URL of the server; defaults to one assembled from IP and PORT.
URL = os.environ.get("URL", "http://" + IP + ":" + PORT)
# Number of samples to send, read from the environment (integer, default 1).
NUM_SAMPLES = int(os.environ.get("NUM_SAMPLES", "1"))
# HTTP status code expected from every successful server call.
STATUS_OK = 200


def _check_response(response, action):
    """Raise if an HTTP call produced no response or a non-``STATUS_OK`` status.

    Explicit exceptions are used instead of ``assert`` so the checks still run
    when Python is started with ``-O``.

    Args:
        response: the response object of a ``requests``/``grequests`` call, or
            ``None`` when grequests could not complete the request.
        action: short description of the call, used in the error message.

    Raises:
        ValueError: if ``response`` is ``None``.
        RuntimeError: if the status code differs from ``STATUS_OK``.
    """
    if response is None:
        raise ValueError(
            "Result is None, probably because the server crashed due to lack of available memory."
        )
    if response.status_code != STATUS_OK:
        # Truncate the body: it may be large, and it is only needed for diagnosis.
        raise RuntimeError(
            f"{action} failed with HTTP status {response.status_code}: {response.text[:200]}"
        )


def main(train_sub_set=None):
    """Run encrypted inference on up to ``NUM_SAMPLES`` samples via the server.

    Steps:
    1. Download ``client.zip`` from ``{URL}/get_client``.
    2. Build an ``FHEModelClient`` from it.
    3. Upload the serialized evaluation keys to ``{URL}/add_key``.
    4. Quantize, encrypt and serialize each sample.
    5. Post all samples concurrently to ``{URL}/compute`` with grequests.
    6. Deserialize, decrypt and de-quantize each result, then print the list
       of predictions.

    Args:
        train_sub_set: samples indexed along the first axis. Each one is sent
            as a batch of one through ``.numpy()``, so this is presumably a
            ``torch.Tensor``. TODO confirm against the intended dataset. When
            ``None`` (the default, e.g. when run as a script), no data source
            is configured and ``NotImplementedError`` is raised before any
            network traffic.

    Raises:
        NotImplementedError: if no input samples are provided.
        ValueError: if ``NUM_SAMPLES`` is below 1, or if a compute request
            returns no response (e.g. the server crashed).
        RuntimeError: if a server endpoint answers with a non-200 status.
    """
    # Fail fast, before contacting the server, when no data source is wired in.
    if train_sub_set is None:
        raise NotImplementedError(
            "No input samples: load the data to infer (e.g. with torchvision) "
            "and pass it as main(train_sub_set=...)."
        )
    # A zero or negative value would make the slice below empty or drop samples.
    if NUM_SAMPLES < 1:
        raise ValueError(f"NUM_SAMPLES must be >= 1, got {NUM_SAMPLES}")

    # Get the necessary data for the client: client.zip holds the
    # crypto-parameters and the pre/post-processing parameters.
    zip_response = requests.get(f"{URL}/get_client")
    _check_response(zip_response, "Downloading client.zip")
    Path("./client.zip").write_bytes(zip_response.content)

    # Get the data to infer (slicing never goes past the available samples)
    X = train_sub_set[:NUM_SAMPLES]

    # Create the client (presumably it reads the client.zip just written to path_dir)
    client = FHEModelClient(path_dir="./", key_dir="./keys")

    # The client first needs to create the private and evaluation keys.
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)

    # Evaluation keys can be quite large but only have to be shared once with the server.
    # len() is the payload size; sys.getsizeof would add the bytes-object header.
    print(f"Evaluation keys size: {len(serialized_evaluation_keys) / 1024 / 1024:.2f} MB")

    # Upload the keys as a multipart file; the server answers with their uid.
    response = requests.post(
        f"{URL}/add_key",
        files={"key": io.BytesIO(initial_bytes=serialized_evaluation_keys)},
    )
    _check_response(response, "Uploading the evaluation keys")
    uid = response.json()["uid"]
    del serialized_evaluation_keys  # no longer needed once uploaded

    # Prepare one query per sample; grequests only sends them in map() below.
    inferences = []
    for index in range(len(X)):
        # A list index keeps the leading batch dimension (batch of one).
        clear_input = X[[index], :].numpy()
        print("Input shape:", clear_input.shape)
        assert isinstance(clear_input, numpy.ndarray)

        print("Quantize/Encrypt")
        encrypted_input = client.quantize_encrypt_serialize(clear_input)
        assert isinstance(encrypted_input, bytes)
        print(f"Encrypted input size: {len(encrypted_input) / 1024 / 1024:.2f} MB")

        print("Posting query")
        inferences.append(
            grequests.post(
                f"{URL}/compute",
                files={"model_input": io.BytesIO(encrypted_input)},
                data={"uid": uid},
            )
        )
        del encrypted_input

    print("Posted!")

    # Send the queries concurrently and unpack the results in request order.
    decrypted_predictions = []
    for result in grequests.map(inferences):
        _check_response(result, "Encrypted inference")
        print("OK!")
        decrypted_predictions.append(client.deserialize_decrypt_dequantize(result.content)[0])
    print(decrypted_predictions)


if __name__ == "__main__":
    # Run the client end to end only when executed as a script, not on import.
    main()