import torch
from PIL import Image
import io
import litserve as lit
import base64
from torchvision import transforms
from src.models.catdog_model import ViTTinyClassifier
import hydra
from omegaconf import DictConfig, OmegaConf
from dotenv import load_dotenv, find_dotenv
import rootutils
from loguru import logger
from src.utils.logging_utils import setup_logger
from pathlib import Path
# Load environment variables
load_dotenv(find_dotenv(".env"))

# Setup root directory
root = rootutils.setup_root(__file__, indicator=".project-root")
logger.info(f"Root directory set to: {root}")


class ImageClassifierAPI(lit.LitAPI):
    def __init__(self, cfg: DictConfig):
        """Initialize the API with the Hydra configuration."""
        super().__init__()
        self.cfg = cfg

        # Validate required config keys; `is None` keeps falsy-but-present
        # values (e.g. an image_size of 0) from being flagged as missing.
        required_keys = ["ckpt_path", "data.image_size", "labels"]
        missing_keys = [
            key for key in required_keys if OmegaConf.select(cfg, key) is None
        ]
        if missing_keys:
            logger.error(f"Missing required config keys: {missing_keys}")
            raise ValueError(f"Missing required config keys: {missing_keys}")
        logger.info(f"Configuration validated: {OmegaConf.to_yaml(cfg)}")

    def setup(self, device):
        """Initialize the model, transforms, and labels."""
        self.device = device
        logger.info("Setting up the model and components.")

        # Log the configuration for debugging
        logger.debug(f"Configuration passed to setup: {OmegaConf.to_yaml(self.cfg)}")

        # Load the model from checkpoint
        try:
            self.model = ViTTinyClassifier.load_from_checkpoint(
                checkpoint_path=self.cfg.ckpt_path
            )
            self.model = self.model.to(device).eval()
            logger.info("Model loaded and moved to device.")
        except FileNotFoundError:
            logger.error(f"Checkpoint file not found: {self.cfg.ckpt_path}")
            raise
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

        # Define transforms (normalization uses the standard ImageNet statistics)
        self.transforms = transforms.Compose(
            [
                transforms.Resize(
                    (self.cfg.data.image_size, self.cfg.data.image_size)
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],  # ImageNet channel means
                    std=[0.229, 0.224, 0.225],  # ImageNet channel stds
                ),
            ]
        )
        logger.info("Transforms initialized.")

        # Load labels
        try:
            self.labels = self.cfg.labels
            logger.info(f"Labels loaded: {self.labels}")
        except Exception as e:
            logger.error(f"Error loading labels: {e}")
            raise ValueError("Failed to load labels from the configuration.") from e

    def decode_request(self, request):
        """Extract the base64-encoded image string from a single request."""
        # logger.info(f"decode_request received: {request}")
        if not isinstance(request, dict) or "image" not in request:
            logger.error(
                "Invalid request format. Expected a dictionary with key 'image'."
            )
            raise ValueError(
                "Invalid request format. Expected a dictionary with key 'image'."
            )
        return request["image"]

    def batch(self, inputs):
        """Decode a list of base64 strings into a single (B, 3, H, W) tensor."""
        # logger.info(f"batch received inputs: {inputs}")
        if not isinstance(inputs, list):
            raise ValueError("Input to batch must be a list.")

        batch_tensors = []
        try:
            for image_bytes in inputs:
                if not isinstance(image_bytes, str):  # Ensure input is a base64 string
                    raise ValueError(
                        f"Input must be a base64-encoded string, got: {type(image_bytes)}"
                    )
                # Decode base64 string to raw bytes
                img_bytes = base64.b64decode(image_bytes)
                # Convert bytes to a PIL image
                try:
                    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                except Exception as img_error:
                    logger.error(f"Failed to process image: {img_error}")
                    raise
                # Apply transforms and add to the batch
                tensor = self.transforms(image)
                batch_tensors.append(tensor)
            return torch.stack(batch_tensors).to(self.device)
        except Exception as e:
            logger.error(f"Error decoding image: {e}")
            raise ValueError("Failed to decode and process the images.") from e

    def predict(self, x):
        """Run the model on the input batch and return class probabilities."""
        with torch.inference_mode():
            outputs = self.model(x)  # logits of shape (B, num_classes)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
        logger.info("Prediction completed.")
        return probabilities

    def unbatch(self, output):
        """Split the batched output back into per-request tensors."""
        return [output[i] for i in range(output.size(0))]

    def encode_response(self, output):
        """Convert a single unbatched model output into an API response."""
        try:
            probs, indices = torch.topk(output, k=1)
            responses = {
                "predictions": [
                    {
                        "label": self.labels[idx.item()],
                        "probability": prob.item(),
                    }
                    for prob, idx in zip(probs, indices)
                ]
            }
            logger.info("Response successfully encoded.")
            return responses
        except Exception as e:
            logger.error(f"Error encoding response: {e}")
            raise ValueError("Failed to encode the response.") from e


@hydra.main(config_path="../configs", config_name="infer", version_base="1.3")
def main(cfg: DictConfig):
    # Initialize loguru
    setup_logger(Path(cfg.paths.log_dir) / "infer.log")
    logger.info("Starting the Image Classifier API server.")

    # Log configuration
    logger.info(f"Configuration: {OmegaConf.to_yaml(cfg)}")

    # Create the API instance with the Hydra config
    api = ImageClassifierAPI(cfg)

    # Configure the server
    server = lit.LitServer(
        api,
        accelerator=cfg.server.accelerator,
        max_batch_size=cfg.server.max_batch_size,
        batch_timeout=cfg.server.batch_timeout,
        devices=cfg.server.devices,
        workers_per_device=cfg.server.workers_per_device,
    )
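    # Dynamic batching: LitServe groups requests that arrive within
    # batch_timeout seconds, up to max_batch_size, into one batch()/predict()
    # call; both values are read from the server section of the config.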
    server.run(port=cfg.server.port)


if __name__ == "__main__":
    main()
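
# Example client call (a sketch: LitServe serves a POST /predict route by
# default; the image path and port 8000 are assumptions, the port being
# whatever cfg.server.port resolves to):
#
#   import base64, requests
#   with open("cat.jpg", "rb") as f:
#       payload = {"image": base64.b64encode(f.read()).decode("utf-8")}
#   r = requests.post("http://localhost:8000/predict", json=payload)
#   print(r.json())  # {"predictions": [{"label": ..., "probability": ...}]}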