import base64
import io
from pathlib import Path

import hydra
import litserve as lit
import rootutils
import torch
from dotenv import load_dotenv, find_dotenv
from loguru import logger
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from torchvision import transforms

from src.models.catdog_model import ViTTinyClassifier
from src.utils.logging_utils import setup_logger
# Load environment variables
load_dotenv(find_dotenv(".env"))

# Set up the project root (rootutils looks for a ".project-root" marker file)
root = rootutils.setup_root(__file__, indicator=".project-root")
logger.info(f"Root directory set to: {root}")

class ImageClassifierAPI(lit.LitAPI):
    def __init__(self, cfg: DictConfig):
        """Initialize the API with a Hydra configuration."""
        super().__init__()
        self.cfg = cfg

        # Validate required config keys. Use an explicit None check so that
        # falsy-but-valid values (e.g. 0 or "") are not flagged as missing.
        required_keys = ["ckpt_path", "data.image_size", "labels"]
        missing_keys = [
            key for key in required_keys if OmegaConf.select(cfg, key) is None
        ]
        if missing_keys:
            logger.error(f"Missing required config keys: {missing_keys}")
            raise ValueError(f"Missing required config keys: {missing_keys}")
        logger.info(f"Configuration validated: {OmegaConf.to_yaml(cfg)}")
    def setup(self, device):
        """Initialize the model and necessary components."""
        self.device = device
        logger.info("Setting up the model and components.")

        # Log the configuration for debugging
        logger.debug(f"Configuration passed to setup: {OmegaConf.to_yaml(self.cfg)}")

        # Load the model from checkpoint
        try:
            self.model = ViTTinyClassifier.load_from_checkpoint(
                checkpoint_path=self.cfg.ckpt_path
            )
            self.model = self.model.to(device).eval()
            logger.info("Model loaded and moved to device.")
        except FileNotFoundError:
            logger.error(f"Checkpoint file not found: {self.cfg.ckpt_path}")
            raise
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

        # Define transforms (standard ImageNet normalization statistics)
        self.transforms = transforms.Compose(
            [
                transforms.Resize(
                    (self.cfg.data.image_size, self.cfg.data.image_size)
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )
        logger.info("Transforms initialized.")

        # Load labels
        try:
            self.labels = self.cfg.labels
            logger.info(f"Labels loaded: {self.labels}")
        except Exception as e:
            logger.error(f"Error loading labels: {e}")
            raise ValueError("Failed to load labels from the configuration.") from e
    def decode_request(self, request):
        """Extract the base64-encoded image from a single request."""
        if not isinstance(request, dict) or "image" not in request:
            logger.error(
                "Invalid request format. Expected a dictionary with key 'image'."
            )
            raise ValueError(
                "Invalid request format. Expected a dictionary with key 'image'."
            )
        return request["image"]
    def batch(self, inputs):
        """Decode a list of base64-encoded images into a single batch tensor."""
        if not isinstance(inputs, list):
            raise ValueError("Input to batch must be a list.")
        batch_tensors = []
        try:
            for image_bytes in inputs:
                # Each input must be a base64-encoded string
                if not isinstance(image_bytes, str):
                    raise ValueError(
                        f"Input must be a base64-encoded string, got: {type(image_bytes)}"
                    )
                # Decode the base64 string to raw bytes
                img_bytes = base64.b64decode(image_bytes)
                # Convert the bytes to a PIL image
                try:
                    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                except Exception as img_error:
                    logger.error(f"Failed to process image: {img_error}")
                    raise
                # Apply transforms and collect into the batch
                batch_tensors.append(self.transforms(image))
            return torch.stack(batch_tensors).to(self.device)
        except Exception as e:
            logger.error(f"Error decoding image: {e}")
            raise ValueError("Failed to decode and process the images.") from e
    def predict(self, x):
        """Run inference on the input batch and return class probabilities."""
        with torch.inference_mode():
            outputs = self.model(x)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
        logger.info("Prediction completed.")
        return probabilities

    def unbatch(self, output):
        """Split the batched output back into per-request tensors."""
        return [output[i] for i in range(output.size(0))]
    def encode_response(self, output):
        """Convert a single unbatched output tensor into an API response."""
        try:
            probs, indices = torch.topk(output, k=1)
            responses = {
                "predictions": [
                    {
                        "label": self.labels[idx.item()],
                        "probability": prob.item(),
                    }
                    for prob, idx in zip(probs, indices)
                ]
            }
            logger.info("Response successfully encoded.")
            return responses
        except Exception as e:
            logger.error(f"Error encoding response: {e}")
            raise ValueError("Failed to encode the response.") from e

# The decorator arguments below are assumptions; adjust config_path and
# config_name to wherever this project keeps its Hydra configs. Without
# @hydra.main, calling main() with no arguments would fail, since cfg
# would never be injected.
@hydra.main(version_base="1.3", config_path="../configs", config_name="infer")
def main(cfg: DictConfig):
    # Initialize loguru
    setup_logger(Path(cfg.paths.log_dir) / "infer.log")
    logger.info("Starting the Image Classifier API server.")

    # Log configuration
    logger.info(f"Configuration: {OmegaConf.to_yaml(cfg)}")

    # Create the API instance with the Hydra config
    api = ImageClassifierAPI(cfg)

    # Configure the server
    server = lit.LitServer(
        api,
        accelerator=cfg.server.accelerator,
        max_batch_size=cfg.server.max_batch_size,
        batch_timeout=cfg.server.batch_timeout,
        devices=cfg.server.devices,
        workers_per_device=cfg.server.workers_per_device,
    )
    server.run(port=cfg.server.port)


if __name__ == "__main__":
    main()
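
# ---------------------------------------------------------------------------
# Example client (a sketch; host, port, and image path are assumptions).
# LitServe serves the API at POST /predict by default, and decode_request
# above expects a JSON body of the form {"image": "<base64 string>"}:
#
#   import base64
#   import requests
#
#   with open("cat.jpg", "rb") as f:
#       payload = {"image": base64.b64encode(f.read()).decode("utf-8")}
#   resp = requests.post("http://localhost:8000/predict", json=payload)
#   print(resp.json())  # e.g. {"predictions": [{"label": ..., "probability": ...}]}
# ---------------------------------------------------------------------------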