File size: 1,435 Bytes
b8335da
084852b
 
 
 
 
b8335da
084852b
 
b8335da
084852b
 
 
 
 
 
1edbe9b
084852b
 
 
 
 
b8335da
084852b
 
 
 
 
 
 
b8335da
084852b
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
import torchvision.models as models
import torch.nn as nn

# Output labels for the classifier. predict() indexes this list with the
# arg-max logit, and the model head is sized from len(class_names), so the
# order presumably has to match the label order used when model_ft.pth was
# trained — verify before reordering.
class_names = [
    '911',
    'cayenne',
    'cayman',
    'macan',
    'panamera',
    'taycan',
]

# Build a ResNet-34 with a len(class_names)-way head and load the fine-tuned
# weights from model_ft.pth for inference only.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# No pretrained download: the strict load_state_dict below overwrites every
# parameter and buffer, so fetching ImageNet weights would only add startup
# time and a network dependency.
model_ft = models.resnet34(weights=None)
num_ftrs = model_ft.fc.in_features
# Replace the 1000-class ImageNet head with one output per class name.
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
# map_location lets a checkpoint that was saved from a GPU load on a
# CPU-only host (without it torch.load tries to restore CUDA tensors).
model_ft.load_state_dict(torch.load('model_ft.pth', map_location=device))
model_ft = model_ft.to(device)
# Inference mode for BatchNorm/Dropout layers.
model_ft.eval()

# Per-channel normalisation statistics (the standard ImageNet mean/std).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Inference-time preprocessing: resize the shorter side to 256, take the
# central 224x224 crop, convert to a tensor, then normalise each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Define the prediction function
def predict(image):
    """Classify a single image and return the predicted class name.

    Args:
        image: PIL image delivered by the Gradio ``Image(type="pil")`` input.
            It may be ``None`` if the form is submitted without an upload.

    Returns:
        The entry of ``class_names`` whose logit is highest.

    Raises:
        ValueError: if ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("No image provided")
    # Normalize() in `preprocess` expects exactly 3 channels; uploads can be
    # RGBA (PNG with alpha), grayscale ("L") or palette ("P") images.
    image = image.convert("RGB")
    # nn.Module has no `.device` attribute; read it from the parameters so
    # the input lands on the same device as the model.
    device = next(model_ft.parameters()).device
    batch = preprocess(image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        outputs = model_ft(batch)
    predicted = outputs.argmax(dim=1)
    return class_names[predicted.item()]

# Create the Gradio interface: one image upload (handed to predict() as a
# PIL image) mapped to a text label.
# `gr.inputs.*` was deprecated in Gradio 3 and removed in Gradio 4; the
# top-level `gr.Image` component works in both.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
)

iface.launch()