Spaces:
Runtime error
Runtime error
Soutrik
commited on
Commit
•
6053557
1
Parent(s):
983c956
litserve
Browse files- configs/infer.yaml +10 -0
- src/client.py +85 -0
- src/infer.py +1 -1
- src/litserve_test_client.py +1 -2
- src/server.py +185 -0
configs/infer.yaml
CHANGED
@@ -40,3 +40,13 @@ seed: 42
|
|
40 |
|
41 |
# name of the experiment
|
42 |
name: "catdog_experiment"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
# name of the experiment
|
42 |
name: "catdog_experiment"
|
43 |
+
|
44 |
+
server:
|
45 |
+
port: 8080
|
46 |
+
max_batch_size: 8
|
47 |
+
batch_timeout: 0.01
|
48 |
+
accelerator: "auto"
|
49 |
+
devices: "auto"
|
50 |
+
workers_per_device: 2
|
51 |
+
|
52 |
+
labels: ["cat", "dog"]
|
src/client.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from urllib.request import urlopen
|
3 |
+
import base64
|
4 |
+
import os
|
5 |
+
|
6 |
+
|
7 |
+
def fetch_image(url):
|
8 |
+
"""
|
9 |
+
Fetch image data from a URL.
|
10 |
+
"""
|
11 |
+
return urlopen(url).read()
|
12 |
+
|
13 |
+
|
14 |
+
def encode_image_to_base64(img_data):
|
15 |
+
"""
|
16 |
+
Encode image bytes to a base64 string.
|
17 |
+
"""
|
18 |
+
return base64.b64encode(img_data).decode("utf-8")
|
19 |
+
|
20 |
+
|
21 |
+
def send_prediction_request(base64_image, server_url):
|
22 |
+
"""
|
23 |
+
Send a single base64 image to the prediction API and retrieve predictions.
|
24 |
+
"""
|
25 |
+
try:
|
26 |
+
response = requests.post(f"{server_url}/predict", json={"image": base64_image})
|
27 |
+
return response
|
28 |
+
except requests.exceptions.RequestException as e:
|
29 |
+
print(f"Error connecting to the server: {e}")
|
30 |
+
return None
|
31 |
+
|
32 |
+
|
33 |
+
def send_batch_prediction_request(base64_images, server_url):
|
34 |
+
"""
|
35 |
+
Send a batch of base64 images to the prediction API and retrieve predictions.
|
36 |
+
"""
|
37 |
+
try:
|
38 |
+
response = requests.post(
|
39 |
+
f"{server_url}/predict", json=[{"image": img} for img in base64_images]
|
40 |
+
)
|
41 |
+
return response
|
42 |
+
except requests.exceptions.RequestException as e:
|
43 |
+
print(f"Error connecting to the server: {e}")
|
44 |
+
return None
|
45 |
+
|
46 |
+
|
47 |
+
def main():
|
48 |
+
# Server URL (default or from environment)
|
49 |
+
server_url = os.getenv("SERVER_URL", "http://localhost:8080")
|
50 |
+
|
51 |
+
# Example URLs for testing
|
52 |
+
image_urls = [
|
53 |
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
|
54 |
+
]
|
55 |
+
|
56 |
+
# Fetch and encode images
|
57 |
+
try:
|
58 |
+
print("Fetching and encoding images...")
|
59 |
+
base64_images = [encode_image_to_base64(fetch_image(url)) for url in image_urls]
|
60 |
+
print("Images fetched and encoded successfully.")
|
61 |
+
except Exception as e:
|
62 |
+
print(f"Error fetching or encoding images: {e}")
|
63 |
+
return
|
64 |
+
|
65 |
+
# Test single image prediction
|
66 |
+
try:
|
67 |
+
print("\n--- Single Image Prediction ---")
|
68 |
+
single_response = send_prediction_request(base64_images[0], server_url)
|
69 |
+
if single_response and single_response.status_code == 200:
|
70 |
+
predictions = single_response.json().get("predictions", [])
|
71 |
+
if predictions:
|
72 |
+
print("Top 5 Predictions:")
|
73 |
+
for pred in predictions:
|
74 |
+
print(f"{pred['label']}: {pred['probability']:.2%}")
|
75 |
+
else:
|
76 |
+
print("No predictions returned.")
|
77 |
+
elif single_response:
|
78 |
+
print(f"Error: {single_response.status_code}")
|
79 |
+
print(single_response.text)
|
80 |
+
except Exception as e:
|
81 |
+
print(f"Error sending single prediction request: {e}")
|
82 |
+
|
83 |
+
|
84 |
+
if __name__ == "__main__":
|
85 |
+
main()
|
src/infer.py
CHANGED
@@ -82,7 +82,7 @@ def download_image(cfg: DictConfig):
|
|
82 |
logger.error(f"Failed to download image. Status code: {response.status_code}")
|
83 |
|
84 |
|
85 |
-
@hydra.main(config_path="../configs", config_name="infer", version_base="1.
|
86 |
def main_infer(cfg: DictConfig):
|
87 |
# Print the configuration
|
88 |
logger.info(OmegaConf.to_yaml(cfg))
|
|
|
82 |
logger.error(f"Failed to download image. Status code: {response.status_code}")
|
83 |
|
84 |
|
85 |
+
@hydra.main(config_path="../configs", config_name="infer", version_base="1.3")
|
86 |
def main_infer(cfg: DictConfig):
|
87 |
# Print the configuration
|
88 |
logger.info(OmegaConf.to_yaml(cfg))
|
src/litserve_test_client.py
CHANGED
@@ -50,8 +50,7 @@ def main():
|
|
50 |
|
51 |
# Example URLs for testing
|
52 |
image_urls = [
|
53 |
-
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
|
54 |
-
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png",
|
55 |
]
|
56 |
|
57 |
# Fetch and encode images
|
|
|
50 |
|
51 |
# Example URLs for testing
|
52 |
image_urls = [
|
53 |
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
|
|
|
54 |
]
|
55 |
|
56 |
# Fetch and encode images
|
src/server.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
import io
|
4 |
+
import litserve as lit
|
5 |
+
import base64
|
6 |
+
from torchvision import transforms
|
7 |
+
from src.models.catdog_model import ViTTinyClassifier
|
8 |
+
import hydra
|
9 |
+
from omegaconf import DictConfig, OmegaConf
|
10 |
+
from dotenv import load_dotenv, find_dotenv
|
11 |
+
import rootutils
|
12 |
+
from loguru import logger
|
13 |
+
from src.utils.logging_utils import setup_logger
|
14 |
+
from pathlib import Path
|
15 |
+
|
16 |
+
# Load environment variables
|
17 |
+
load_dotenv(find_dotenv(".env"))
|
18 |
+
|
19 |
+
# Setup root directory
|
20 |
+
root = rootutils.setup_root(__file__, indicator=".project-root")
|
21 |
+
logger.info(f"Root directory set to: {root}")
|
22 |
+
|
23 |
+
|
24 |
+
class ImageClassifierAPI(lit.LitAPI):
|
25 |
+
def __init__(self, cfg: DictConfig):
|
26 |
+
"""
|
27 |
+
Initialize the API with Hydra configuration.
|
28 |
+
"""
|
29 |
+
super().__init__()
|
30 |
+
self.cfg = cfg
|
31 |
+
|
32 |
+
# Validate required config keys
|
33 |
+
required_keys = ["ckpt_path", "data.image_size", "labels"]
|
34 |
+
missing_keys = [key for key in required_keys if not OmegaConf.select(cfg, key)]
|
35 |
+
if missing_keys:
|
36 |
+
logger.error(f"Missing required config keys: {missing_keys}")
|
37 |
+
raise ValueError(f"Missing required config keys: {missing_keys}")
|
38 |
+
logger.info(f"Configuration validated: {OmegaConf.to_yaml(cfg)}")
|
39 |
+
|
40 |
+
def setup(self, device):
|
41 |
+
"""Initialize the model and necessary components."""
|
42 |
+
self.device = device
|
43 |
+
logger.info("Setting up the model and components.")
|
44 |
+
|
45 |
+
# Log the configuration for debugging
|
46 |
+
logger.debug(f"Configuration passed to setup: {OmegaConf.to_yaml(self.cfg)}")
|
47 |
+
|
48 |
+
# Load the model from checkpoint
|
49 |
+
try:
|
50 |
+
self.model = ViTTinyClassifier.load_from_checkpoint(
|
51 |
+
checkpoint_path=self.cfg.ckpt_path
|
52 |
+
)
|
53 |
+
self.model = self.model.to(device).eval()
|
54 |
+
logger.info("Model loaded and moved to device.")
|
55 |
+
except FileNotFoundError:
|
56 |
+
logger.error(f"Checkpoint file not found: {self.cfg.ckpt_path}")
|
57 |
+
raise
|
58 |
+
except Exception as e:
|
59 |
+
logger.error(f"Error loading model: {e}")
|
60 |
+
raise
|
61 |
+
|
62 |
+
# Define transforms
|
63 |
+
self.transforms = transforms.Compose(
|
64 |
+
[
|
65 |
+
transforms.Resize((self.cfg.data.image_size, self.cfg.data.image_size)),
|
66 |
+
transforms.ToTensor(),
|
67 |
+
transforms.Normalize(
|
68 |
+
mean=[0.485, 0.456, 0.406], # Hard-coded mean
|
69 |
+
std=[0.229, 0.224, 0.225], # Hard-coded std
|
70 |
+
),
|
71 |
+
]
|
72 |
+
)
|
73 |
+
logger.info("Transforms initialized.")
|
74 |
+
|
75 |
+
# Load labels
|
76 |
+
try:
|
77 |
+
self.labels = self.cfg.labels
|
78 |
+
logger.info(f"Labels loaded: {self.labels}")
|
79 |
+
except Exception as e:
|
80 |
+
logger.error(f"Error loading labels: {e}")
|
81 |
+
raise ValueError("Failed to load labels from the configuration.")
|
82 |
+
|
83 |
+
def decode_request(self, request):
|
84 |
+
"""Handle both single and batch inputs."""
|
85 |
+
logger.info(f"decode_request received: {request}")
|
86 |
+
if not isinstance(request, dict) or "image" not in request:
|
87 |
+
logger.error(
|
88 |
+
"Invalid request format. Expected a dictionary with key 'image'."
|
89 |
+
)
|
90 |
+
raise ValueError(
|
91 |
+
"Invalid request format. Expected a dictionary with key 'image'."
|
92 |
+
)
|
93 |
+
return request["image"]
|
94 |
+
|
95 |
+
def batch(self, inputs):
|
96 |
+
"""Batch process images."""
|
97 |
+
logger.info(f"batch received inputs: {inputs}")
|
98 |
+
if not isinstance(inputs, list):
|
99 |
+
raise ValueError("Input to batch must be a list.")
|
100 |
+
|
101 |
+
batch_tensors = []
|
102 |
+
try:
|
103 |
+
for image_bytes in inputs:
|
104 |
+
if not isinstance(image_bytes, str): # Ensure input is a base64 string
|
105 |
+
raise ValueError(
|
106 |
+
f"Input must be a base64-encoded string, got: {type(image_bytes)}"
|
107 |
+
)
|
108 |
+
|
109 |
+
# Decode base64 string to bytes
|
110 |
+
img_bytes = base64.b64decode(image_bytes)
|
111 |
+
|
112 |
+
# Convert bytes to PIL Image
|
113 |
+
try:
|
114 |
+
image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
115 |
+
except Exception as img_error:
|
116 |
+
logger.error(f"Failed to process image: {img_error}")
|
117 |
+
raise
|
118 |
+
|
119 |
+
# Apply transforms and add to batch
|
120 |
+
tensor = self.transforms(image)
|
121 |
+
batch_tensors.append(tensor)
|
122 |
+
|
123 |
+
return torch.stack(batch_tensors).to(self.device)
|
124 |
+
except Exception as e:
|
125 |
+
logger.error(f"Error decoding image: {e}")
|
126 |
+
raise ValueError("Failed to decode and process the images.")
|
127 |
+
|
128 |
+
def predict(self, x):
|
129 |
+
"""Make predictions on the input batch."""
|
130 |
+
with torch.inference_mode():
|
131 |
+
outputs = self.model(x)
|
132 |
+
probabilities = torch.nn.functional.softmax(outputs, dim=1)
|
133 |
+
logger.info("Prediction completed.")
|
134 |
+
return probabilities
|
135 |
+
|
136 |
+
def unbatch(self, output):
|
137 |
+
"""Unbatch the output."""
|
138 |
+
return [output[i] for i in range(output.size(0))]
|
139 |
+
|
140 |
+
def encode_response(self, output):
|
141 |
+
"""Convert model output to API response for batches."""
|
142 |
+
try:
|
143 |
+
probs, indices = torch.topk(output, k=1)
|
144 |
+
responses = {
|
145 |
+
"predictions": [
|
146 |
+
{
|
147 |
+
"label": self.labels[idx.item()],
|
148 |
+
"probability": prob.item(),
|
149 |
+
}
|
150 |
+
for prob, idx in zip(probs, indices)
|
151 |
+
]
|
152 |
+
}
|
153 |
+
logger.info("Batch response successfully encoded.")
|
154 |
+
return responses
|
155 |
+
except Exception as e:
|
156 |
+
logger.error(f"Error encoding batch response: {e}")
|
157 |
+
raise ValueError("Failed to encode the batch response.")
|
158 |
+
|
159 |
+
|
160 |
+
@hydra.main(config_path="../configs", config_name="infer", version_base="1.3")
|
161 |
+
def main(cfg: DictConfig):
|
162 |
+
# Initialize loguru
|
163 |
+
setup_logger(Path(cfg.paths.log_dir) / "infer.log")
|
164 |
+
logger.info("Starting the Image Classifier API server.")
|
165 |
+
|
166 |
+
# Log configuration
|
167 |
+
logger.info(f"Configuration: {OmegaConf.to_yaml(cfg)}")
|
168 |
+
|
169 |
+
# Create the API instance with the Hydra config
|
170 |
+
api = ImageClassifierAPI(cfg)
|
171 |
+
|
172 |
+
# Configure the server
|
173 |
+
server = lit.LitServer(
|
174 |
+
api,
|
175 |
+
accelerator=cfg.server.accelerator,
|
176 |
+
max_batch_size=cfg.server.max_batch_size,
|
177 |
+
batch_timeout=cfg.server.batch_timeout,
|
178 |
+
devices=cfg.server.devices,
|
179 |
+
workers_per_device=cfg.server.workers_per_device,
|
180 |
+
)
|
181 |
+
server.run(port=cfg.server.port)
|
182 |
+
|
183 |
+
|
184 |
+
if __name__ == "__main__":
|
185 |
+
main()
|