File size: 1,156 Bytes
1e98bba
 
ccd736a
1e98bba
 
ccd736a
1e98bba
 
 
 
 
 
64d20d7
 
 
 
 
 
1e98bba
64d20d7
1e98bba
ccd736a
1e98bba
 
 
64d20d7
1e98bba
 
6c9d55c
1e98bba
 
 
64d20d7
 
1e98bba
 
 
 
 
 
 
 
 
b54e32f
1e98bba
 
64d20d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torchvision.transforms as transforms
import gradio as gr
from PIL import Image
from model import SimpleCNN

def preprocess_image(image):
    """Convert an input image into a normalized single-image batch tensor.

    Args:
        image: Either an array accepted by ``PIL.Image.fromarray`` (Gradio's
            ``gr.Image()`` delivers a NumPy array by default) or an
            already-loaded ``PIL.Image.Image``. Non-RGB modes (grayscale,
            RGBA, ...) are converted to RGB.

    Returns:
        A float tensor of shape ``(1, 3, 32, 32)``: resized to 32x32,
        three channels, with a leading batch dimension of 1.

    Raises:
        ValueError: If *image* is ``None`` (e.g. an empty/cleared input).
    """
    if image is None:
        raise ValueError("preprocess_image() received no image")

    transform = transforms.Compose([
        # Fixed 32x32 input — presumably the resolution SimpleCNN was trained on.
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        # NOTE(review): these are the ImageNet channel statistics, not
        # CIFAR-10's own; they must match the normalization used when the
        # checkpoint was trained — confirm against the training script.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # PIL images can be used directly; only array-like input needs wrapping.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Normalize above expects exactly three channels, so force RGB.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # ToTensor yields (C, H, W); unsqueeze adds the batch dimension.
    return transform(image).unsqueeze(0)

def predict_image(model, image):
    """Return the index of the highest-scoring class for one input.

    Args:
        model: Callable (e.g. a ``torch.nn.Module``) mapping *image* to a
            score tensor; the class axis is assumed to be dimension 1.
        image: Model input holding a single example, so that the reduced
            result contains exactly one element.

    Returns:
        The arg-max class index as a plain Python ``int``.
    """
    # Inference only — no autograd graph is needed.
    with torch.no_grad():
        scores = model(image)

    # Reduce over the class axis and keep just the winning position.
    best_index = torch.max(scores, dim=1).indices
    return best_index.item()

def main(checkpoint_path='cifar10_model.pth'):
    """Load the trained SimpleCNN and serve it through a live Gradio demo.

    Args:
        checkpoint_path: Path to the saved ``state_dict``. Defaults to
            ``'cifar10_model.pth'`` relative to the working directory.
    """
    model = SimpleCNN()
    # map_location='cpu' lets a checkpoint saved from a GPU run load on a
    # CPU-only host; without it torch.load tries to restore CUDA tensors.
    # NOTE(review): torch.load unpickles the file — only load checkpoints
    # you trust (recent torch versions also support weights_only=True).
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)

    # Set the model to evaluation mode (inference behavior for any
    # dropout / batch-norm layers SimpleCNN may contain).
    model.eval()

    def classify(img):
        # With live=True the function re-runs whenever the input changes;
        # an empty/cleared image input arrives as None, so there is nothing
        # to classify and no label is shown.
        if img is None:
            return None
        return predict_image(model, preprocess_image(img))

    iface = gr.Interface(
        fn=classify,
        inputs=gr.Image(),
        outputs="label",
        live=True,
    )

    iface.launch()

if __name__ == "__main__":
    # Launch the demo only when this file is run as a script, not on import.
    main()