import base64
import io
import logging

import litserve as lit
import requests
import timm
import torch
from PIL import Image
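# Request lifecycle with server-side batching enabled (max_batch_size > 1):
# LitServe calls decode_request() once per incoming request, collects the
# results into a list for batch(), runs predict() once on the stacked tensor,
# then splits the output with unbatch() and serializes each item with
# encode_response().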

class ImageClassifierAPI(lit.LitAPI):
    def setup(self, device):
        """Initialize the model and necessary components."""
        self.device = device
        logging.info("Setting up the model and components.")
        # Create and load the model in eval mode. Gradients are disabled at
        # inference time by the @torch.no_grad() decorator on predict() below,
        # so no no_grad context is needed here.
        self.model = timm.create_model("resnet50.a1_in1k", pretrained=True)
        self.model = self.model.to(device).eval()
        # Build the preprocessing pipeline from the model's data config
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)
        # Load the ImageNet class labels
        url = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
        try:
            self.labels = requests.get(url, timeout=10).text.strip().split("\n")
            logging.info("Labels loaded successfully.")
        except Exception as e:
            logging.error(f"Failed to load labels: {e}")
            self.labels = []
    def decode_request(self, request):
        """Extract the base64-encoded image string from a single request."""
        # Avoid logging the full payload: it contains a large base64 string.
        logging.info("decode_request received a request.")
        if isinstance(request, dict) and "image" in request:
            return request["image"]
        raise ValueError('Request must be a JSON object with an "image" field.')
    def batch(self, inputs):
        """Decode a list of base64 strings and stack them into one batch tensor."""
        if not isinstance(inputs, list):
            raise ValueError("Input to batch must be a list.")
        logging.info(f"batch received {len(inputs)} input(s).")
        batch_tensors = []
        try:
            for image_bytes in inputs:
                if not isinstance(image_bytes, str):  # Ensure input is a base64 string
                    raise ValueError(
                        f"Input must be a base64-encoded string, got: {type(image_bytes)}"
                    )
                # Decode base64 string to bytes
                img_bytes = base64.b64decode(image_bytes)
                # Convert bytes to PIL Image
                image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                # Apply transforms and add to batch
                tensor = self.transforms(image)
                batch_tensors.append(tensor)
            return torch.stack(batch_tensors).to(self.device)
        except Exception as e:
            logging.error(f"Error decoding image: {e}")
            raise ValueError("Failed to decode and process the images.") from e
    @torch.no_grad()
    def predict(self, x):
        """Make predictions on the input batch."""
        outputs = self.model(x)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        logging.info("Prediction completed.")
        return probabilities

    def unbatch(self, output):
        """Split the batched output back into per-request results."""
        return [output[i] for i in range(output.size(0))]
    def encode_response(self, output):
        """Convert one request's probabilities into a top-5 JSON response."""
        try:
            probs, indices = torch.topk(output, k=5)
            responses = {
                "predictions": [
                    {
                        "label": self.labels[idx.item()],
                        "probability": prob.item(),
                    }
                    for prob, idx in zip(probs, indices)
                ]
            }
            logging.info("Response successfully encoded.")
            return responses
        except Exception as e:
            logging.error(f"Error encoding response: {e}")
            raise ValueError("Failed to encode the response.") from e

if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )
    logging.info("Starting the Image Classifier API server.")
    api = ImageClassifierAPI()
    # Basic server setup:
    # server = lit.LitServer(api, accelerator="auto", devices="auto")
    # Batched server setup:
    # server = lit.LitServer(
    #     api, accelerator="auto", max_batch_size=16, batch_timeout=0.01, devices="auto"
    # )
    # Increase the number of workers per device to handle more concurrent requests
    server = lit.LitServer(
        api,
        accelerator="auto",
        max_batch_size=4,
        batch_timeout=0.01,
        devices="auto",
        workers_per_device=2,
    )
    server.run(port=8080)
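    # A minimal client sketch for manual testing (run from a separate process,
    # not this file). LitServe serves the API at the /predict route by default,
    # and decode_request above expects a JSON body of the form
    # {"image": <base64 string>}. The file name "cat.jpg" is only an
    # illustrative placeholder:
    #
    #   import base64, requests
    #   with open("cat.jpg", "rb") as f:
    #       payload = {"image": base64.b64encode(f.read()).decode("utf-8")}
    #   resp = requests.post("http://localhost:8080/predict", json=payload)
    #   print(resp.json())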