Soutrik committed on
Commit
6053557
1 Parent(s): 983c956
configs/infer.yaml CHANGED
@@ -40,3 +40,13 @@ seed: 42
40
 
41
  # name of the experiment
42
  name: "catdog_experiment"
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # name of the experiment
42
  name: "catdog_experiment"
43
+
44
+ server:
45
+ port: 8080
46
+ max_batch_size: 8
47
+ batch_timeout: 0.01
48
+ accelerator: "auto"
49
+ devices: "auto"
50
+ workers_per_device: 2
51
+
52
+ labels: ["cat", "dog"]
src/client.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from urllib.request import urlopen
3
+ import base64
4
+ import os
5
+
6
+
7
def fetch_image(url, timeout=30):
    """
    Fetch image data from a URL.

    Args:
        url: Location of the image; anything ``urllib.request.urlopen`` accepts.
        timeout: Seconds to wait on blocking network operations before giving
            up. Defaults to 30 so a stalled server cannot hang the client.

    Returns:
        The raw response body as bytes.

    Raises:
        urllib.error.URLError: If the URL cannot be opened.
    """
    # The context manager closes the underlying connection even if read() fails.
    with urlopen(url, timeout=timeout) as response:
        return response.read()
12
+
13
+
14
def encode_image_to_base64(img_data):
    """
    Turn raw image bytes into their base64 text representation.

    Args:
        img_data: Bytes-like object holding the image content.

    Returns:
        The base64 encoding of ``img_data`` as a ``str``.
    """
    encoded_bytes = base64.b64encode(img_data)
    encoded_text = encoded_bytes.decode("utf-8")
    return encoded_text
19
+
20
+
21
def send_prediction_request(base64_image, server_url, timeout=30):
    """
    Send a single base64 image to the prediction API and retrieve predictions.

    Args:
        base64_image: Base64-encoded image string, sent as ``{"image": ...}``.
        server_url: Base URL of the server; ``/predict`` is appended to it.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely on an unresponsive server.

    Returns:
        The ``requests.Response`` on any completed HTTP exchange (including
        error status codes), or ``None`` if the request itself failed.
    """
    try:
        return requests.post(
            f"{server_url}/predict",
            json={"image": base64_image},
            timeout=timeout,
        )
    except requests.exceptions.RequestException as e:
        # Timeouts are RequestException subclasses, so they are reported here too.
        print(f"Error connecting to the server: {e}")
        return None
31
+
32
+
33
def send_batch_prediction_request(base64_images, server_url, timeout=30):
    """
    Send a batch of base64 images to the prediction API and retrieve predictions.

    The images are posted in one request as a JSON list of ``{"image": ...}``
    objects.
    NOTE(review): confirm the ``/predict`` endpoint accepts a list payload;
    this function does not check it.

    Args:
        base64_images: Iterable of base64-encoded image strings.
        server_url: Base URL of the server; ``/predict`` is appended to it.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely on an unresponsive server.

    Returns:
        The ``requests.Response`` on any completed HTTP exchange (including
        error status codes), or ``None`` if the request itself failed.
    """
    payload = [{"image": img} for img in base64_images]
    try:
        return requests.post(f"{server_url}/predict", json=payload, timeout=timeout)
    except requests.exceptions.RequestException as e:
        # Timeouts are RequestException subclasses, so they are reported here too.
        print(f"Error connecting to the server: {e}")
        return None
45
+
46
+
47
def main():
    """
    Fetch a sample image, send it to the prediction server and print the result.

    The server URL is read from the ``SERVER_URL`` environment variable and
    defaults to ``http://localhost:8080``. All failures are reported on stdout
    rather than raised.
    """
    # Server URL (default or from environment)
    server_url = os.getenv("SERVER_URL", "http://localhost:8080")

    # Example URLs for testing
    image_urls = [
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
    ]

    # Fetch and encode images
    try:
        print("Fetching and encoding images...")
        base64_images = [encode_image_to_base64(fetch_image(url)) for url in image_urls]
        print("Images fetched and encoded successfully.")
    except Exception as e:
        print(f"Error fetching or encoding images: {e}")
        return

    # Test single image prediction
    try:
        print("\n--- Single Image Prediction ---")
        single_response = send_prediction_request(base64_images[0], server_url)

        # Compare against None explicitly: a requests.Response is falsy for
        # 4xx/5xx statuses, so a truthiness test would silently skip the
        # error report below for exactly the responses it is meant for.
        if single_response is None:
            # The connection error was already printed by the request helper.
            return

        if single_response.status_code != 200:
            print(f"Error: {single_response.status_code}")
            print(single_response.text)
            return

        predictions = single_response.json().get("predictions", [])
        if not predictions:
            print("No predictions returned.")
            return

        print("Top 5 Predictions:")
        for pred in predictions:
            print(f"{pred['label']}: {pred['probability']:.2%}")
    except Exception as e:
        print(f"Error sending single prediction request: {e}")


if __name__ == "__main__":
    main()
src/infer.py CHANGED
@@ -82,7 +82,7 @@ def download_image(cfg: DictConfig):
82
  logger.error(f"Failed to download image. Status code: {response.status_code}")
83
 
84
 
85
- @hydra.main(config_path="../configs", config_name="infer", version_base="1.1")
86
  def main_infer(cfg: DictConfig):
87
  # Print the configuration
88
  logger.info(OmegaConf.to_yaml(cfg))
 
82
  logger.error(f"Failed to download image. Status code: {response.status_code}")
83
 
84
 
85
+ @hydra.main(config_path="../configs", config_name="infer", version_base="1.3")
86
  def main_infer(cfg: DictConfig):
87
  # Print the configuration
88
  logger.info(OmegaConf.to_yaml(cfg))
src/litserve_test_client.py CHANGED
@@ -50,8 +50,7 @@ def main():
50
 
51
  # Example URLs for testing
52
  image_urls = [
53
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png",
54
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png",
55
  ]
56
 
57
  # Fetch and encode images
 
50
 
51
  # Example URLs for testing
52
  image_urls = [
53
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
 
54
  ]
55
 
56
  # Fetch and encode images
src/server.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import io
4
+ import litserve as lit
5
+ import base64
6
+ from torchvision import transforms
7
+ from src.models.catdog_model import ViTTinyClassifier
8
+ import hydra
9
+ from omegaconf import DictConfig, OmegaConf
10
+ from dotenv import load_dotenv, find_dotenv
11
+ import rootutils
12
+ from loguru import logger
13
+ from src.utils.logging_utils import setup_logger
14
+ from pathlib import Path
15
+
16
+ # Load environment variables
17
+ load_dotenv(find_dotenv(".env"))
18
+
19
+ # Setup root directory
20
+ root = rootutils.setup_root(__file__, indicator=".project-root")
21
+ logger.info(f"Root directory set to: {root}")
22
+
23
+
24
class ImageClassifierAPI(lit.LitAPI):
    """LitServe API that classifies base64-encoded images with a ViTTinyClassifier.

    Request flow as implemented here: ``decode_request`` extracts the base64
    string, ``batch`` decodes and transforms the strings into one tensor,
    ``predict`` runs the model, ``unbatch`` splits the result per request and
    ``encode_response`` formats the top-1 label and probability.

    NOTE(review): image decoding lives in ``batch``, so ``predict`` presumably
    only receives a tensor when the server actually batches requests --
    confirm before running with batching disabled.
    """

    def __init__(self, cfg: DictConfig):
        """
        Initialize the API with Hydra configuration.

        Args:
            cfg: Configuration that must provide non-empty ``ckpt_path``,
                ``data.image_size`` and ``labels`` entries.

        Raises:
            ValueError: If any required key is missing or falsy.
        """
        super().__init__()
        self.cfg = cfg

        # Validate required config keys
        required_keys = ["ckpt_path", "data.image_size", "labels"]
        missing_keys = [key for key in required_keys if not OmegaConf.select(cfg, key)]
        if missing_keys:
            logger.error(f"Missing required config keys: {missing_keys}")
            raise ValueError(f"Missing required config keys: {missing_keys}")
        logger.info(f"Configuration validated: {OmegaConf.to_yaml(cfg)}")

    def setup(self, device):
        """Initialize the model, the image transforms and the label list.

        Args:
            device: Device the model and input batches are moved to.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            ValueError: If the labels cannot be read from the configuration.
            Exception: Any other model-loading error is logged and re-raised.
        """
        self.device = device
        logger.info("Setting up the model and components.")

        # Log the configuration for debugging
        logger.debug(f"Configuration passed to setup: {OmegaConf.to_yaml(self.cfg)}")

        # Load the model from checkpoint
        try:
            self.model = ViTTinyClassifier.load_from_checkpoint(
                checkpoint_path=self.cfg.ckpt_path
            )
            self.model = self.model.to(device).eval()
            logger.info("Model loaded and moved to device.")
        except FileNotFoundError:
            logger.error(f"Checkpoint file not found: {self.cfg.ckpt_path}")
            raise
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

        # Define transforms: square resize, tensor conversion, normalization
        image_size = self.cfg.data.image_size
        self.transforms = transforms.Compose(
            [
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],  # Hard-coded mean
                    std=[0.229, 0.224, 0.225],  # Hard-coded std
                ),
            ]
        )
        logger.info("Transforms initialized.")

        # Load labels
        try:
            self.labels = self.cfg.labels
            logger.info(f"Labels loaded: {self.labels}")
        except Exception as e:
            logger.error(f"Error loading labels: {e}")
            raise ValueError("Failed to load labels from the configuration.") from e

    def decode_request(self, request):
        """Extract the base64 image string from a single request payload.

        Args:
            request: Expected to be a dict containing the key ``"image"``.

        Returns:
            The value stored under ``"image"`` (still base64-encoded).

        Raises:
            ValueError: If ``request`` is not a dict or lacks ``"image"``.
        """
        # Log only the payload type: the payload itself carries a full
        # base64-encoded image and would bloat the logs.
        logger.debug(f"decode_request received a {type(request).__name__} payload")
        if not isinstance(request, dict) or "image" not in request:
            logger.error(
                "Invalid request format. Expected a dictionary with key 'image'."
            )
            raise ValueError(
                "Invalid request format. Expected a dictionary with key 'image'."
            )
        return request["image"]

    def _image_to_tensor(self, encoded_image):
        """Decode one base64 string into an RGB image and apply the transforms."""
        if not isinstance(encoded_image, str):  # Ensure input is a base64 string
            raise ValueError(
                f"Input must be a base64-encoded string, got: {type(encoded_image)}"
            )

        # Decode base64 string to bytes
        img_bytes = base64.b64decode(encoded_image)

        # Convert bytes to PIL Image
        try:
            image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except Exception as img_error:
            logger.error(f"Failed to process image: {img_error}")
            raise

        return self.transforms(image)

    def batch(self, inputs):
        """Decode a list of base64 images and stack them into one tensor.

        Args:
            inputs: List of base64-encoded image strings.

        Returns:
            The stacked, transformed images moved to ``self.device``.

        Raises:
            ValueError: If ``inputs`` is not a list, or if any image fails to
                decode or transform (the original error is chained as cause).
        """
        if not isinstance(inputs, list):
            raise ValueError("Input to batch must be a list.")
        # Log the batch size rather than the raw base64 strings.
        logger.info(f"batch received {len(inputs)} input(s).")

        try:
            batch_tensors = [self._image_to_tensor(item) for item in inputs]
            return torch.stack(batch_tensors).to(self.device)
        except Exception as e:
            logger.error(f"Error decoding image: {e}")
            raise ValueError("Failed to decode and process the images.") from e

    def predict(self, x):
        """Run the model on the input batch and return softmax probabilities.

        Softmax is taken over dimension 1 of the model output.
        """
        with torch.inference_mode():
            outputs = self.model(x)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
        logger.info("Prediction completed.")
        return probabilities

    def unbatch(self, output):
        """Split the batched output into a list with one tensor per row."""
        return list(output.unbind(0))

    def encode_response(self, output):
        """Convert model output into the API response with the top-1 prediction.

        Args:
            output: Probability tensor produced by ``predict`` / ``unbatch``.

        Returns:
            ``{"predictions": [{"label": ..., "probability": ...}, ...]}``.

        Raises:
            ValueError: If the output cannot be encoded (the original error is
                chained as cause).
        """
        try:
            probs, indices = torch.topk(output, k=1)
            responses = {
                "predictions": [
                    {
                        "label": self.labels[idx.item()],
                        "probability": prob.item(),
                    }
                    for prob, idx in zip(probs, indices)
                ]
            }
            logger.info("Batch response successfully encoded.")
            return responses
        except Exception as e:
            logger.error(f"Error encoding batch response: {e}")
            raise ValueError("Failed to encode the batch response.") from e
158
+
159
+
160
@hydra.main(config_path="../configs", config_name="infer", version_base="1.3")
def main(cfg: DictConfig):
    """Start the LitServe image-classification server from the Hydra config.

    Reads the log directory from ``cfg.paths.log_dir`` and the server options
    (accelerator, devices, batching, workers, port) from ``cfg.server``.
    """
    # Route loguru output to <log_dir>/infer.log
    log_file = Path(cfg.paths.log_dir) / "infer.log"
    setup_logger(log_file)
    logger.info("Starting the Image Classifier API server.")

    # Dump the resolved configuration for traceability
    logger.info(f"Configuration: {OmegaConf.to_yaml(cfg)}")

    # Build the API first; it validates the config on construction
    classifier_api = ImageClassifierAPI(cfg)

    # Gather the server options from the config
    server_cfg = cfg.server
    server_options = {
        "accelerator": server_cfg.accelerator,
        "max_batch_size": server_cfg.max_batch_size,
        "batch_timeout": server_cfg.batch_timeout,
        "devices": server_cfg.devices,
        "workers_per_device": server_cfg.workers_per_device,
    }
    server = lit.LitServer(classifier_api, **server_options)
    server.run(port=server_cfg.port)


if __name__ == "__main__":
    main()