# NOTE: page residue captured with this source from the hosting "Spaces" UI;
# the status shown there was: Runtime error.
import base64
import io
import logging
import os

import litserve as lit
import requests
import timm
import torch
from PIL import Image
class ImageClassifierAPI(lit.LitAPI):
    """LitServe API that classifies base64-encoded images with a timm ResNet-50.

    Hooks defined here: ``decode_request`` (per request) extracts the base64
    string, ``batch`` decodes and stacks the images, ``predict`` returns
    softmax probabilities, ``unbatch`` splits them per request and
    ``encode_response`` formats the top-k labelled predictions.
    """

    # Text file with one class name per line (ILSVRC-2012 lemmas, per the name).
    LABELS_URL = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
    # Seconds to wait for the labels download before giving up.
    LABELS_TIMEOUT = 10
    # Number of highest-probability predictions returned per image.
    TOP_K = 5

    def setup(self, device):
        """Load the model, its inference preprocessing and the class labels.

        Args:
            device: Device assigned to this worker; the model and every input
                batch are moved there.
        """
        self.device = device
        logging.info("Setting up the model and components.")

        # Pretrained ResNet-50, switched to eval mode for inference.
        self.model = timm.create_model("resnet50.a1_in1k", pretrained=True)
        self.model = self.model.to(device).eval()

        # Inference-time transform derived from the model's own data config.
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)

        self.labels = self._load_labels()

    def _load_labels(self):
        """Download the class-name list; return [] (best effort) on failure."""
        try:
            # A timeout prevents a stalled download from hanging worker setup.
            response = requests.get(self.LABELS_URL, timeout=self.LABELS_TIMEOUT)
            # Reject HTTP error pages instead of parsing them as labels.
            response.raise_for_status()
        except requests.RequestException as e:
            logging.error(f"Failed to load labels: {e}")
            return []
        labels = response.text.strip().split("\n")
        logging.info("Labels loaded successfully.")
        return labels

    def decode_request(self, request):
        """Extract the base64-encoded image string from one request.

        Accepts ``{"image": "<base64>"}`` or a bare base64 string.

        Raises:
            ValueError: if the payload matches neither form.
        """
        # Log only the payload type: the payload itself can be megabytes of base64.
        logging.info("decode_request received a %s payload", type(request).__name__)
        if isinstance(request, dict):
            if "image" not in request:
                raise ValueError("Request is missing the 'image' field.")
            return request["image"]
        if isinstance(request, str):
            return request
        raise ValueError(f"Unsupported request type: {type(request)}")

    def batch(self, inputs):
        """Decode base64 images and stack them into one preprocessed tensor.

        Args:
            inputs: list of base64-encoded image strings, one per request.

        Returns:
            The transformed images stacked along dim 0, on ``self.device``.

        Raises:
            ValueError: if ``inputs`` is not a non-empty list, an item is not a
                string, or an image cannot be decoded/processed.
        """
        if not isinstance(inputs, list):
            raise ValueError("Input to batch must be a list.")
        if not inputs:
            raise ValueError("Input to batch must not be empty.")
        logging.info("batch received %d inputs", len(inputs))
        tensors = [self._decode_image(item) for item in inputs]
        return torch.stack(tensors).to(self.device)

    def _decode_image(self, image_b64):
        """Turn one base64 string into the model's preprocessed input tensor."""
        # Checked outside the try so this specific message reaches the caller.
        if not isinstance(image_b64, str):
            raise ValueError(
                f"Input must be a base64-encoded string, got: {type(image_b64)}"
            )
        try:
            img_bytes = base64.b64decode(image_b64)
            # The context manager releases PIL's file handle once converted.
            with Image.open(io.BytesIO(img_bytes)) as img:
                image = img.convert("RGB")
            return self.transforms(image)
        except Exception as e:
            logging.exception("Error decoding image: %s", e)
            raise ValueError("Failed to decode and process the images.") from e

    def predict(self, x):
        """Return softmax class probabilities (over dim 1) for the batch ``x``."""
        # Inference only: disable autograd to save memory and time.
        with torch.no_grad():
            logits = self.model(x)
            probabilities = torch.nn.functional.softmax(logits, dim=1)
        logging.info("Prediction completed.")
        return probabilities

    def unbatch(self, output):
        """Split the batched probabilities into one tensor per request."""
        return list(output.unbind(0))

    def _label_for(self, idx):
        """Return the class name for ``idx``, or the index as text when the
        label list is unavailable or too short (e.g. the download failed)."""
        if 0 <= idx < len(self.labels):
            return self.labels[idx]
        return str(idx)

    def encode_response(self, output):
        """Format one image's probabilities as its top-k labelled predictions.

        Returns:
            ``{"predictions": [{"label": ..., "probability": float}, ...]}``
            in descending probability order.

        Raises:
            ValueError: if the output cannot be encoded.
        """
        try:
            # Guard against models with fewer classes than TOP_K.
            k = min(self.TOP_K, output.numel())
            probs, indices = torch.topk(output, k=k)
            predictions = [
                {"label": self._label_for(idx), "probability": prob}
                for prob, idx in zip(probs.tolist(), indices.tolist())
            ]
        except Exception as e:
            logging.exception("Error encoding batch response: %s", e)
            raise ValueError("Failed to encode the batch response.") from e
        logging.info("Batch response successfully encoded.")
        return {"predictions": predictions}
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )
    logging.info("Starting the Image Classifier API server.")
    api = ImageClassifierAPI()

    # Dynamic batching: group up to 4 requests, waiting at most
    # batch_timeout=0.01 between them, and run two workers per device to
    # serve more concurrent requests.
    server = lit.LitServer(
        api,
        accelerator="auto",
        max_batch_size=4,
        batch_timeout=0.01,
        devices="auto",
        workers_per_device=2,
    )

    # Port defaults to 8080; set the PORT environment variable when the
    # hosting platform expects a different one.
    server.run(port=int(os.environ.get("PORT", "8080")))