# Last commit: f702a0a "Update app.py" by osbm (verified).
import functools

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models
from torchvision import transforms
# Fine-tuned checkpoint: a state_dict for ResNet-50 with a single-logit head.
# Resolved relative to the current working directory.
CHECKPOINT_PATH = "best_f1.pth"

# Class index -> label; index 1 is chosen when sigmoid(logit) > 0.5.
INT2LABEL = {0: "cat", 1: "dog"}

# Inference preprocessing: fixed 224x224 resize, then ImageNet mean/std
# normalization. Built once at import instead of on every request.
VALID_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the ResNet-50 classifier, load the fine-tuned weights, cache it."""
    # No pretrained ImageNet weights are requested: load_state_dict (strict by
    # default) overwrites every parameter and buffer, so downloading them would
    # only waste time and network I/O.
    model = models.resnet50()
    model.fc = nn.Linear(model.fc.in_features, 1)
    # torch.load unpickles the file; only point CHECKPOINT_PATH at trusted data.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict(image):
    """Classify an image as ``"cat"`` or ``"dog"``.

    Args:
        image: Image array delivered by the Gradio ``"image"`` input
            (presumably an H x W x C or H x W numpy array -- TODO confirm the
            configured input type). It is cast to uint8 and converted to RGB,
            so grayscale and RGBA inputs are accepted.

    Returns:
        ``"dog"`` if the sigmoid of the model's single output logit exceeds
        0.5, otherwise ``"cat"``.

    Raises:
        gr.Error: if ``image`` is None (e.g. the form was submitted without an
            image), so the UI shows a readable message instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Let Pillow infer the mode from the array shape, then normalize to RGB;
    # forcing mode="RGB" on a 2-D or 4-channel array misreads the buffer.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    input_batch = VALID_TRANSFORM(pil_image).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        logit = _load_model()(input_batch)
    probability = torch.sigmoid(logit).squeeze().item()

    return INT2LABEL[1 if probability > 0.5 else 0]
# Gradio UI: one image input, a label output, and bundled example images.
demo = gr.Interface(
    predict,
    inputs="image",
    outputs="label",
    title="Cats vs Dogs",
    description="This model predicts whether an image contains a cat or a dog.",
    examples=["assets/7.jpg", "assets/44.jpg", "assets/82.jpg", "assets/83.jpg"],
)

# Start the server only when run as a script (`python app.py`), so importing
# this module builds `demo` without launching anything.
if __name__ == "__main__":
    demo.launch()