import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from torchvision import models
def predict(image):
    print(type(image))
    # Gradio's "image" input passes a NumPy array; convert it to a PIL image
    image = Image.fromarray(image.astype('uint8'), 'RGB')

    # Load ResNet-50 and swap the classifier head for a single-logit binary output;
    # the pretrained weights are then overwritten by the fine-tuned checkpoint
    model = models.resnet50(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 1)
    model.load_state_dict(torch.load("best_f1.pth", map_location=torch.device('cpu')))
    model.eval()

    # Preprocess image
    valid_transform = transforms.Compose([
        # transforms.ToPILImage(),        # Not needed: the input is already a PIL Image
        transforms.Resize((224, 224)),    # Resize the image to 224 x 224
        transforms.ToTensor(),            # Convert the image to a PyTorch tensor
        transforms.Normalize(             # Normalize with ImageNet statistics
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    input_batch = valid_transform(image).unsqueeze(0)

    # Make prediction: sigmoid of the single logit gives the probability of "dog"
    with torch.no_grad():
        output = model(input_batch)
    output = torch.sigmoid(output).squeeze().item()
    if output > 0.5:
        predicted = 1
    else:
        predicted = 0

    int2label = {0: "cat", 1: "dog"}
    return int2label[predicted]
demo = gr.Interface(
    predict,
    inputs="image",
    outputs="label",
    title="Cats vs Dogs",
    description="This model predicts whether an image contains a cat or a dog.",
    examples=["assets/7.jpg", "assets/44.jpg", "assets/82.jpg", "assets/83.jpg"],
)

demo.launch()
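
# Quick local sanity check (a sketch, not part of the deployed app): Gradio's
# "image" input hands predict() an H x W x 3 uint8 NumPy array, so the function
# can be exercised directly on one of the bundled example images. Assumes
# best_f1.pth and assets/7.jpg are present in the working directory; uncomment
# and run in place of demo.launch() when testing.
# import numpy as np
# sample = np.array(Image.open("assets/7.jpg").convert("RGB"))
# print(predict(sample))  # prints "cat" or "dog"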