# Source: matikosowy — commit 084852b ("torch model").
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
import torchvision.models as models
import torch.nn as nn
# Class labels, listed in the index order of the classifier head's outputs.
class_names = ['911', 'cayenne', 'cayman', 'macan', 'panamera', 'taycan']

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build a ResNet-34 whose final layer is sized for ``class_names``, then load
# the fine-tuned weights. ``weights=None`` avoids downloading ImageNet weights:
# the strict ``load_state_dict`` below overwrites every parameter and buffer.
model_ft = models.resnet34(weights=None)

# Freeze everything except layer4 (presumably left over from fine-tuning;
# requires_grad does not affect forward-pass outputs).
for param in model_ft.parameters():
    param.requires_grad = False
for param in model_ft.layer4.parameters():
    param.requires_grad = True

num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)

# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
model_ft.load_state_dict(torch.load('model_ft.pth', map_location=device))
model_ft.eval()  # inference mode: BatchNorm uses running statistics
# Per-channel (R, G, B) normalization statistics of ImageNet, which the
# ResNet backbone was originally trained on.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Image preprocessing pipeline: scale the shorter side to 256 px, take the
# central 224x224 crop, convert to a float tensor in [0, 1], then normalize
# each channel with the statistics above.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Define the prediction function
def predict(image):
    """Return which entry of ``class_names`` the model predicts for *image*.

    Args:
        image: A ``PIL.Image.Image`` as supplied by the Gradio image input,
            or ``None`` when the user submits without uploading an image.

    Returns:
        The predicted class name, or an empty string when *image* is None.
    """
    if image is None:
        return ""
    # The network expects 3 input channels. RGBA, grayscale or palette images
    # would otherwise make Normalize fail on the channel dimension.
    image = image.convert("RGB")
    # nn.Module has no ``.device`` attribute; read it from the weights instead.
    device = next(model_ft.parameters()).device
    batch = preprocess(image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        logits = model_ft(batch)
    predicted = logits.argmax(dim=1)
    return class_names[predicted.item()]
# Create the Gradio interface: upload an image and get the predicted class
# name back as text. ``gr.inputs.Image`` was deprecated in Gradio 3 and
# removed in Gradio 4; ``gr.Image`` is the supported input component.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
)
iface.launch()