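"""Serve a timm ResNet-50 image classifier over HTTP with LitServe.

Clients POST a JSON payload containing a base64-encoded image and receive
the top-5 ImageNet predictions. Batching and multiple workers per device
are configured in the __main__ block below.
"""
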
import base64
import io
import logging

import litserve as lit
import requests
import timm
import torch
from PIL import Image


class ImageClassifierAPI(lit.LitAPI):
    def setup(self, device):
        """Initialize the model and necessary components."""
        self.device = device
        logging.info("Setting up the model and components.")

        # Create and load the model
        self.model = timm.create_model("resnet50.a1_in1k", pretrained=True)
        self.model = self.model.to(device).eval()

        # Build the preprocessing pipeline that matches the model's
        # pretraining configuration (resize, crop, normalization).
        # Inference-time gradients are already disabled by the
        # @torch.no_grad() decorator on predict().
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transforms = timm.data.create_transform(
            **data_config, is_training=False
        )

        # Load the ImageNet class labels (WordNet lemmas, one per line)
        url = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            self.labels = response.text.strip().split("\n")
            logging.info("Labels loaded successfully.")
        except requests.RequestException as e:
            logging.error(f"Failed to load labels: {e}")
            self.labels = []
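
        # The label file has one class per line in index order (line 0 is
        # "tench, Tinca_tinca"), matching the model's output indices.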

    def decode_request(self, request):
        """Extract the base64-encoded image string from a request payload."""
        # Do not log the request body itself: it contains the full
        # base64-encoded image and would flood the logs.
        if isinstance(request, dict) and "image" in request:
            return request["image"]
        raise ValueError("Request must be a JSON object with an 'image' field.")
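
    # LitServe hook order with batching enabled (max_batch_size > 1):
    #   decode_request (per request) -> batch -> predict -> unbatch
    #   -> encode_response (per request)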

    def batch(self, inputs):
        """Decode a list of base64 strings into a single batched tensor."""
        if not isinstance(inputs, list):
            raise ValueError("Input to batch must be a list.")
        # Log only the batch size; the items are full base64 strings.
        logging.info(f"batch received {len(inputs)} input(s).")

        batch_tensors = []
        try:
            for image_bytes in inputs:
                if not isinstance(image_bytes, str):  # Ensure input is a base64 string
                    raise ValueError(
                        f"Input must be a base64-encoded string, got: {type(image_bytes)}"
                    )

                # Decode base64 string to bytes
                img_bytes = base64.b64decode(image_bytes)

                # Convert bytes to PIL Image
                image = Image.open(io.BytesIO(img_bytes)).convert("RGB")

                # Apply transforms and add to batch
                tensor = self.transforms(image)
                batch_tensors.append(tensor)

            return torch.stack(batch_tensors).to(self.device)
        except Exception as e:
            logging.error(f"Error decoding image: {e}")
            raise ValueError("Failed to decode and process the images.") from e

    @torch.no_grad()
    def predict(self, x):
        """Make predictions on the input batch."""
        outputs = self.model(x)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        logging.info("Prediction completed.")
        return probabilities

    def unbatch(self, output):
        """Split the batched probability tensor into one tensor per request."""
        return [output[i] for i in range(output.size(0))]

    def encode_response(self, output):
        """Convert one request's probability vector into the API response."""
        try:
            probs, indices = torch.topk(output, k=5)
            response = {
                "predictions": [
                    {
                        "label": self.labels[idx.item()],
                        "probability": prob.item(),
                    }
                    for prob, idx in zip(probs, indices)
                ]
            }
            logging.info("Response successfully encoded.")
            return response
        except Exception as e:
            logging.error(f"Error encoding response: {e}")
            raise ValueError("Failed to encode the response.") from e
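
    # Example response for one image (the label text is illustrative):
    #   {"predictions": [{"label": "tabby, tabby_cat", "probability": 0.87},
    #                    ...]}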


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )
    logging.info("Starting the Image Classifier API server.")

    api = ImageClassifierAPI()

    # Basic server setup:
    # server = lit.LitServer(api, accelerator="auto", devices="auto")
    #
    # With batching (up to 16 requests per batch, 10 ms batch window):
    # server = lit.LitServer(
    #     api, accelerator="auto", max_batch_size=16, batch_timeout=0.01, devices="auto"
    # )
    #
    # Final configuration: batching plus two workers per device to handle
    # more concurrent requests.
    server = lit.LitServer(
        api,
        accelerator="auto",
        max_batch_size=4,
        batch_timeout=0.01,
        devices="auto",
        workers_per_device=2,
    )
    server.run(port=8080)
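

# Example client (a sketch: "cat.jpg" is a hypothetical image file, and the
# server above must already be running; LitServe serves at /predict):
#
#   import base64, requests
#
#   with open("cat.jpg", "rb") as f:
#       payload = {"image": base64.b64encode(f.read()).decode("utf-8")}
#   resp = requests.post("http://localhost:8080/predict", json=payload)
#   print(resp.json()["predictions"])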