Spaces:
Sleeping
Sleeping
import torch | |
import joblib | |
from torchvision import transforms | |
from PIL import Image | |
from models.freshness_model import MultiOutputModel | |
import torch.nn.functional as F # For applying sigmoid | |
import gradio as gr | |
import json | |
# Define the paths | |
MODEL_PATH = 'models/multi_output_model.pth' | |
LABEL_ENCODER_TYPE_PATH = 'models/label_encoder_type.pkl' | |
# Load the label encoder | |
label_encoder_type = joblib.load(LABEL_ENCODER_TYPE_PATH) | |
# Load the model | |
model = MultiOutputModel(num_classes_type=len(label_encoder_type.classes_)) | |
model.load_state_dict(torch.load(MODEL_PATH)) | |
model.eval() # Set the model to evaluation mode | |
# Define the image transformations | |
transform = transforms.Compose([ | |
transforms.Resize((224, 224)), # Resize to match the input size of the model | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize using ImageNet stats | |
]) | |
def preprocess_image(image: Image.Image): | |
"""Preprocess the input image.""" | |
return transform(image).unsqueeze(0) # Add batch dimension | |
def run_inference(image: Image.Image): | |
"""Run inference on a single image.""" | |
# Preprocess the image | |
input_tensor = preprocess_image(image) | |
# Move to device (GPU if available) | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model.to(device) | |
input_tensor = input_tensor.to(device) | |
# Forward pass | |
with torch.no_grad(): | |
type_output, freshness_output = model(input_tensor) | |
# Apply sigmoid to freshness output (to get probability) | |
freshness_output = torch.sigmoid(freshness_output.squeeze()) # Sigmoid for binary classification | |
# Decode predictions | |
_, predicted_type_idx = torch.max(type_output, 1) | |
predicted_type = label_encoder_type.inverse_transform([predicted_type_idx.item()]) | |
output = {} | |
output['type'] = str(predicted_type[0]) | |
output['freshness'] = str(float(freshness_output.item())) | |
json_output = json.dumps(output) | |
return str(json_output) | |
# Define Gradio input and output components | |
image_input = gr.Image(type="pil") | |
# Create the Gradio interface | |
demo = gr.Interface( | |
fn=run_inference, | |
inputs=image_input, | |
outputs='text', | |
title="Frshness prediction", | |
description="Upload an image, and the model will detect objects and return the number of objects along with the image showing the bounding boxes." | |
) | |
demo.launch(share=True) | |