"""Gradio demo that classifies an uploaded image as a cat or a dog."""

import functools

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# Checkpoint holding the fine-tuned weights (ResNet-50 with a 1-unit head).
WEIGHTS_PATH = "best_f1.pth"

# Sigmoid outputs strictly above this value map to class 1, otherwise class 0.
DECISION_THRESHOLD = 0.5

# Class index -> label returned to the UI.
INT2LABEL = {0: "cat", 1: "dog"}

# Inference-time preprocessing, built once instead of on every request.
VALID_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Per-channel statistics commonly used with torchvision ImageNet models.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the classifier and load its weights; cached after the first call.

    The backbone is created without pretrained weights: ``load_state_dict``
    is strict by default, so every parameter and buffer is replaced by the
    checkpoint anyway and downloading ImageNet weights first would be wasted
    work.

    Returns:
        The ResNet-50 model with a single-logit head, on CPU, in eval mode.
    """
    model = models.resnet50()
    model.fc = nn.Linear(model.fc.in_features, 1)
    # NOTE: torch.load unpickles the file; only point WEIGHTS_PATH at a
    # checkpoint you trust.
    state_dict = torch.load(WEIGHTS_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict(image):
    """Classify an image as ``"cat"`` or ``"dog"``.

    Args:
        image: Pixel array as delivered by the Gradio ``"image"`` input
            (expected to be an H x W x 3 RGB array; grayscale or RGBA arrays
            are converted to RGB before preprocessing).

    Returns:
        ``"dog"`` if the sigmoid of the model's logit exceeds 0.5, otherwise
        ``"cat"``.

    Raises:
        ValueError: If no image was supplied (``image`` is ``None``).
    """
    if image is None:
        raise ValueError("No image provided.")

    # Let Pillow infer the mode from the array shape, then normalise to RGB so
    # the 3-channel Normalize step always receives three channels.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    input_batch = VALID_TRANSFORM(pil_image).unsqueeze(0)  # add batch dim

    model = _load_model()
    with torch.no_grad():
        logit = model(input_batch)
    probability = torch.sigmoid(logit).squeeze().item()

    predicted = 1 if probability > DECISION_THRESHOLD else 0
    return INT2LABEL[predicted]


demo = gr.Interface(
    predict,
    inputs="image",
    outputs="label",
    title="Cats vs Dogs",
    description="This model predicts whether an image contains a cat or a dog.",
    examples=[
        "assets/7.jpg",
        "assets/44.jpg",
        "assets/82.jpg",
        "assets/83.jpg",
    ],
)

# Kept at module level so running (or importing) the script still starts the
# app exactly as before.
demo.launch()