# Author: aanshbasu | last commit 9e30d59 ("test") | file size 2.45 kB
import torch
import joblib
from torchvision import transforms
from PIL import Image
from models.freshness_model import MultiOutputModel
import torch.nn.functional as F # For applying sigmoid
import gradio as gr
import json
# Paths to the trained artifacts (relative to the current working directory).
MODEL_PATH = 'models/multi_output_model.pth'
LABEL_ENCODER_TYPE_PATH = 'models/label_encoder_type.pkl'

# Load the fitted label encoder; run_inference uses it to map a predicted
# class index back to a type name via inverse_transform.
# NOTE(review): joblib.load unpickles the file -- only load trusted artifacts.
label_encoder_type = joblib.load(LABEL_ENCODER_TYPE_PATH)

# Build the model with one type output per known class, then load the weights.
model = MultiOutputModel(num_classes_type=len(label_encoder_type.classes_))
# map_location="cpu" lets a checkpoint that was saved on a GPU load on a
# CPU-only host; the model is moved to the inference device later on.
# NOTE(review): torch.load also unpickles by default -- keep the checkpoint trusted.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()  # Evaluation mode (affects layers such as dropout / batch norm).
# Preprocessing pipeline applied to every input image:
#   1. resize to 224x224, matching the input size of the model;
#   2. convert the PIL image to a C x H x W float tensor;
#   3. normalize each channel with the ImageNet mean / std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def preprocess_image(image: Image.Image):
    """Convert a PIL image into a batched, normalized model input tensor.

    Args:
        image: Input image in any PIL mode. It is converted to RGB first
            because the normalization step uses three per-channel statistics.

    Returns:
        A tensor of shape (1, 3, 224, 224) ready to pass to the model.
    """
    # RGBA, grayscale ("L") or palette ("P") images would otherwise yield a
    # channel count that does not match the 3-channel Normalize in `transform`.
    rgb_image = image.convert("RGB")
    return transform(rgb_image).unsqueeze(0)  # Add the batch dimension.
def run_inference(image: Image.Image):
    """Predict the type and freshness of a single image.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input.

    Returns:
        A JSON string ``{"type": <type name>, "freshness": <probability>}``.
        The freshness value is the sigmoid of the model's freshness output,
        a number in [0, 1]. It is encoded as a string inside the JSON to
        stay compatible with existing consumers.

    Raises:
        ValueError: If no image was provided.
    """
    # Gradio can pass None when the user submits without uploading an image.
    if image is None:
        raise ValueError("No image provided. Please upload an image.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Moving the model is a no-op once it already lives on `device`.
    model.to(device)
    input_tensor = preprocess_image(image).to(device)

    with torch.no_grad():
        type_logits, freshness_logit = model(input_tensor)

    # Sigmoid turns the single freshness output into a probability in [0, 1];
    # .item() requires exactly one element to remain after squeeze().
    freshness_prob = torch.sigmoid(freshness_logit.squeeze()).item()

    # Index of the highest-scoring type for the (only) batch element.
    predicted_type_idx = type_logits.argmax(dim=1).item()
    predicted_type = label_encoder_type.inverse_transform([predicted_type_idx])[0]

    output = {
        'type': str(predicted_type),
        'freshness': str(freshness_prob),
    }
    return json.dumps(output)
# Gradio UI: a single PIL image input and a plain-text output that displays
# the JSON string returned by run_inference.
image_input = gr.Image(type="pil")

demo = gr.Interface(
    fn=run_inference,
    inputs=image_input,
    outputs='text',
    title="Freshness prediction",
    description=(
        "Upload an image, and the model will predict its type and a "
        "freshness score between 0 and 1, returned as JSON."
    ),
)

# share=True additionally requests a temporary public link for the app.
demo.launch(share=True)